diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d8e7afd..4bba529 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,6 +34,9 @@ jobs: - name: Run checks (lint, vet, fmt-check, test) run: make check + # Note: Validation tests with real cloud providers run in separate workflows + # See .github/workflows/validation-*.yml for provider-specific validation tests + - name: Run security scan run: make security continue-on-error: true diff --git a/.github/workflows/validation-lambdalabs.yml b/.github/workflows/validation-lambdalabs.yml new file mode 100644 index 0000000..da7e731 --- /dev/null +++ b/.github/workflows/validation-lambdalabs.yml @@ -0,0 +1,58 @@ +name: LambdaLabs Validation Tests + +on: + schedule: + # Run daily at 2 AM UTC + - cron: '0 2 * * *' + workflow_dispatch: + # Allow manual triggering + pull_request: + paths: + - 'internal/lambdalabs/**' + - 'pkg/v1/**' + branches: [ main ] + +jobs: + lambdalabs-validation: + name: LambdaLabs Provider Validation + runs-on: ubuntu-latest + if: github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' || github.event_name == 'pull_request' + + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: '1.23.0' + + - name: Cache Go modules + uses: actions/cache@v4 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + + - name: Install dependencies + run: make deps + + - name: Run LambdaLabs validation tests + env: + LAMBDALABS_API_KEY: ${{ secrets.LAMBDALABS_API_KEY }} + TEST_PRIVATE_KEY_BASE64: ${{ secrets.TEST_PRIVATE_KEY_BASE64 }} + TEST_PUBLIC_KEY_BASE64: ${{ secrets.TEST_PUBLIC_KEY_BASE64 }} + VALIDATION_TEST: true + run: | + cd internal/lambdalabs + go test -v -short=false -timeout=30m ./... 
+ + - name: Upload test results + uses: actions/upload-artifact@v4 + if: always() + with: + name: lambdalabs-validation-results + path: | + internal/lambdalabs/coverage.out diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..024618b --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +.env +__debug_bin* diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..719c068 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,29 @@ +{ + "go.useLanguageServer": true, + "gopls": { + "gofumpt": true + }, + "[go]": { + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.organizeImports": "explicit" + } + }, + "[go.mod]": { + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.organizeImports": "explicit" + } + }, + "go.lintTool": "golangci-lint", + "go.lintFlags": [ + "--fast" + ], + "go.testFlags": [ + "-race", + "-v", + "-count=1", + "-timeout=30m" + ], + "go.testEnvFile": "${workspaceFolder}/.env" +} \ No newline at end of file diff --git a/Makefile b/Makefile index af26cfe..154128f 100644 --- a/Makefile +++ b/Makefile @@ -3,10 +3,16 @@ # Variables BINARY_NAME=compute -MODULE_NAME=github.com/brevdev/compute +MODULE_NAME=github.com/brevdev/cloud BUILD_DIR=build COVERAGE_DIR=coverage +# Load environment variables from .env file if it exists +ifneq (,$(wildcard .env)) + include .env + export +endif + # Go related variables GOCMD=go GOBUILD=$(GOCMD) build @@ -59,6 +65,18 @@ build-windows: .PHONY: test test: @echo "Running tests..." + $(GOTEST) -v -short ./... + +# Run validation tests +.PHONY: test-validation +test-validation: + @echo "Running validation tests..." + $(GOTEST) -v -short=false ./... + +# Run all tests including validation +.PHONY: test-all +test-all: + @echo "Running all tests..." $(GOTEST) -v ./... 
# Run tests with coverage @@ -181,7 +199,9 @@ help: @echo "Available targets:" @echo " build - Build the project" @echo " build-all - Build for Linux, macOS, and Windows" - @echo " test - Run tests" + @echo " test - Run tests (with -short flag)" + @echo " test-validation - Run validation tests (without -short flag)" + @echo " test-all - Run all tests including validation" @echo " test-coverage - Run tests with coverage report" @echo " test-race - Run tests with race detection" @echo " bench - Run benchmarks" @@ -197,4 +217,4 @@ help: @echo " docs - Generate documentation" @echo " check - Run all checks (lint, vet, fmt-check, test)" @echo " install-tools - Install development tools" - @echo " help - Show this help message" \ No newline at end of file + @echo " help - Show this help message" \ No newline at end of file diff --git a/docs/VALIDATION_TESTING.md b/docs/VALIDATION_TESTING.md new file mode 100644 index 0000000..95e0a4e --- /dev/null +++ b/docs/VALIDATION_TESTING.md @@ -0,0 +1,135 @@ +# Validation Testing + +This document describes the validation testing framework for cloud provider implementations. + +## Overview + +Validation tests verify that cloud provider implementations correctly implement the SDK interfaces by making real API calls to cloud providers. These tests are separate from unit tests and require actual cloud credentials. + +## Running Validation Tests + +### Locally + +```bash +# Skip validation tests (default) +make test + +# Run validation tests +make test-validation + +# Run all tests +make test-all +``` + +### Environment Variables + +Each provider requires specific environment variables: + +- **LambdaLabs**: `LAMBDALABS_API_KEY` + +### CI/CD + +Validation tests run automatically: +- Daily via scheduled workflows +- On pull requests to `main` that modify `internal/lambdalabs/**` or `pkg/v1/**` +- Manually via workflow dispatch + +## Adding New Providers + +1. Create validation test file: `internal/{provider}/v1/validation_test.go` +2. 
Use the shared validation package with provider-specific configuration: + +```go +func TestValidationFunctions(t *testing.T) { + apiKey := os.Getenv("YOUR_PROVIDER_API_KEY") + if apiKey == "" { + t.Skip("YOUR_PROVIDER_API_KEY not set, skipping YourProvider validation tests") + } + + config := validation.ProviderConfig{ + Credential: NewYourProviderCredential("validation-test", apiKey), + } + validation.RunValidationSuite(t, config) +} +``` + +3. Create CI workflow: `.github/workflows/validation-{provider}.yml` +4. Add environment variables to CI secrets +5. Update this documentation + +## Shared Validation Package + +The validation tests use a shared package at `internal/validation/` that provides: +- `RunValidationSuite()` - Tests all validation functions from pkg/v1/ +- `RunInstanceLifecycleValidation()` - Tests instance lifecycle operations +- `ProviderConfig` - Configuration for provider-specific setup using CloudCredential + +The `ProviderConfig` uses the existing `CloudCredential` interface which acts as a factory for `CloudClient` instances. The provider name is automatically obtained from the credential's `GetCloudProviderID()` method. This approach eliminates code duplication and ensures consistent validation testing across all providers while leveraging the existing credential abstraction. + +## Test Structure + +Validation tests use `testing.Short()` guards: + +```go +func TestValidation(t *testing.T) { + if testing.Short() { + t.Skip("Skipping validation tests in short mode") + } + // ... validation logic +} +``` + +This ensures validation tests only run when explicitly requested. 
+ +## Validation Functions Tested + +The framework tests all validation functions from the SDK: + +### Instance Management +- `ValidateCreateInstance` - Tests instance creation with timing and attribute validation +- `ValidateListCreatedInstance` - Tests instance listing and filtering +- `ValidateTerminateInstance` - Tests instance termination +- `ValidateMergeInstanceForUpdate` - Tests instance update merging logic + +### Instance Types +- `ValidateGetInstanceTypes` - Tests instance type retrieval and filtering +- `ValidateRegionalInstanceTypes` - Tests regional instance type filtering +- `ValidateStableInstanceTypeIDs` - Tests instance type ID stability + +### Locations +- `ValidateGetLocations` - Tests location retrieval and availability + +## Security Considerations + +- Validation tests use real cloud credentials stored as GitHub secrets +- Tests create and destroy real cloud resources +- Proper cleanup is implemented to avoid resource leaks +- Tests are designed to be cost-effective and use minimal resources + +## Troubleshooting + +### Common Issues + +1. **Missing credentials**: Ensure environment variables are set +2. **Quota limits**: Tests may skip if quota is exceeded +3. **Resource availability**: Tests adapt to available instance types and locations +4. **Network timeouts**: Tests use appropriate timeouts for cloud operations + +### Debugging + +```bash +# Run specific validation test +go test -v -short=false -run TestValidationFunctions ./internal/lambdalabs/v1/ + +# Run with verbose output +go test -v -short=false -timeout=20m ./internal/lambdalabs/v1/ +``` + +## Contributing + +When adding new validation functions: + +1. Add the validation function to the appropriate `pkg/v1/*.go` file +2. Add corresponding test in `internal/{provider}/v1/validation_test.go` +3. Ensure proper cleanup and error handling +4. 
Update this documentation diff --git a/docs/example-dot-env b/docs/example-dot-env new file mode 100644 index 0000000..eb7fbbf --- /dev/null +++ b/docs/example-dot-env @@ -0,0 +1,3 @@ +LAMBDALABS_API_KEY=secret_my-api-key_********** +TEST_PRIVATE_KEY_BASE64=LS0tLS1CRUdJTiBSU0EgUFJJVk... +TEST_PUBLIC_KEY_BASE64=LS0tLS1CRUdJTiBQVUJMSUMgS0V... \ No newline at end of file diff --git a/go.mod b/go.mod index d6b93ec..cfaadd6 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ toolchain go1.23.2 require ( github.com/alecthomas/units v0.0.0-20240927000941-0f3dac36c52b github.com/bojanz/currency v1.3.1 - github.com/brevdev/compute v0.0.0-20250805004716-bc4fe363e0ea + github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 github.com/jarcoal/httpmock v1.4.0 github.com/nebius/gosdk v0.0.0-20250731090238-d96c0d4a5930 @@ -16,19 +16,22 @@ require ( require ( buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.31.0-20231030212536-12f9cba37c9d.2 // indirect + github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cockroachdb/apd/v3 v3.2.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/gliderlabs/ssh v0.3.8 // indirect github.com/gofrs/flock v0.12.1 // indirect github.com/golang-jwt/jwt/v4 v4.5.2 // indirect github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.1.0 // indirect github.com/kr/text v0.2.0 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + golang.org/x/crypto v0.41.0 // indirect golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect - golang.org/x/net v0.38.0 // indirect - golang.org/x/sync v0.12.0 // indirect - golang.org/x/sys v0.31.0 // indirect - golang.org/x/text v0.23.0 // indirect + golang.org/x/net v0.42.0 // indirect + golang.org/x/sync v0.16.0 // indirect + golang.org/x/sys v0.35.0 // indirect + golang.org/x/text v0.28.0 // indirect google.golang.org/genproto/googleapis/rpc 
v0.0.0-20240318140521-94a12d6c2237 // indirect google.golang.org/grpc v1.64.1 // indirect google.golang.org/protobuf v1.33.0 // indirect diff --git a/go.sum b/go.sum index 011147e..949962d 100644 --- a/go.sum +++ b/go.sum @@ -2,10 +2,10 @@ buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.31.0-2023103021253 buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.31.0-20231030212536-12f9cba37c9d.2/go.mod h1:xafc+XIsTxTy76GJQ1TKgvJWsSugFBqMaN27WhUblew= github.com/alecthomas/units v0.0.0-20240927000941-0f3dac36c52b h1:mimo19zliBX/vSQ6PWWSL9lK8qwHozUj03+zLoEB8O0= github.com/alecthomas/units v0.0.0-20240927000941-0f3dac36c52b/go.mod h1:fvzegU4vN3H1qMT+8wDmzjAcDONcgo2/SZ/TyfdUOFs= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/bojanz/currency v1.3.1 h1:3BUAvy/5hU/Pzqg5nrQslVihV50QG+A2xKPoQw1RKH4= github.com/bojanz/currency v1.3.1/go.mod h1:jNoZiJyRTqoU5DFoa+n+9lputxPUDa8Fz8BdDrW06Go= -github.com/brevdev/compute v0.0.0-20250805004716-bc4fe363e0ea h1:U+mj2Q4lYMMkCuflMzFeIzf0tiASimf8/juGhcAT3DY= -github.com/brevdev/compute v0.0.0-20250805004716-bc4fe363e0ea/go.mod h1:rxhy3+lWmdnVABBys6l+Z+rVDeKa5nyy0avoQkTmTFw= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cockroachdb/apd/v3 v3.2.1 h1:U+8j7t0axsIgvQUqthuNm82HIrYXodOV2iWLWtEaIwg= @@ -14,14 +14,16 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gliderlabs/ssh v0.3.8 
h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c= +github.com/gliderlabs/ssh v0.3.8/go.mod h1:xYoytBv1sV0aL3CavoDuJIQNURXkkfPA/wxQ1pL1fAU= github.com/gofrs/flock v0.12.1 h1:MTLVXXHf8ekldpJk3AKicLij9MdwOWkZ+a/jHHZby9E= github.com/gofrs/flock v0.12.1/go.mod h1:9zxTsyu5xtJ9DK+1tFZyibEV7y3uwDxPPfbxeeHCoD0= github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.1.0 h1:pRhl55Yx1eC7BZ1N+BBWwnKaMyD8uC+34TLdndZMAKk= @@ -52,16 +54,26 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= +golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= 
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= +golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237 h1:NnYq6UN9ReLM9/Y01KWNOWyI5xQ9kbIms5GGJVwS/Yc= google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237/go.mod h1:WtryC6hu0hhx87FDGxWCDptyssuo68sk10vYjF+T9fY= diff --git a/internal/collections/collections.go b/internal/collections/collections.go new file mode 100644 index 0000000..e244816 --- /dev/null +++ b/internal/collections/collections.go @@ -0,0 +1,98 @@ +package collections + +import ( + "fmt" + + "github.com/cenkalti/backoff/v4" +) + +func Flatten[T any](listOfLists [][]T) []T { + result := []T{} + for _, list := range listOfLists { + 
result = append(result, list...) + } + return result +} + +func GroupBy[K comparable, A any](list []A, keyGetter func(A) K) map[K][]A { + result := map[K][]A{} + for _, item := range list { + key := keyGetter(item) + result[key] = append(result[key], item) + } + return result +} + +func MapE[T, R any](items []T, mapper func(T) (R, error)) ([]R, error) { + results := []R{} + for _, item := range items { + res, err := mapper(item) + if err != nil { + return results, err + } + results = append(results, res) + } + return results, nil +} + +func GetMapValues[K comparable, V any](m map[K]V) []V { + values := []V{} + for _, v := range m { + values = append(values, v) + } + return values +} + +// loops over list and returns when has returns true +func ListHas[K any](list []K, has func(l K) bool) bool { + k := Find(list, has) + if k != nil { + return true + } + return false +} + +func MapHasKey[K comparable, V any](m map[K]V, key K) bool { + _, ok := m[key] + return ok +} + +func ListContains[K comparable](list []K, item K) bool { + return ListHas(list, func(l K) bool { return l == item }) +} + +func Find[T any](list []T, f func(T) bool) *T { + for _, item := range list { + if f(item) { + return &item + } + } + return nil +} + +// returns those that are true +func Filter[T any](list []T, f func(T) bool) []T { + result := []T{} + for _, item := range list { + if f(item) { + result = append(result, item) + } + } + return result +} + +func Ptr[T any](x T) *T { + return &x +} + +func RetryWithDataAndAttemptCount[T any](o backoff.OperationWithData[T], b backoff.BackOff) (T, error) { + attemptCount := 0 + t, err := backoff.RetryWithData(func() (T, error) { + attemptCount++ + return o() + }, b) + if err != nil { + return t, fmt.Errorf("attemptCount %d: %w", attemptCount, err) + } + return t, nil +} diff --git a/internal/lambdalabs/CONTRIBUTE.md b/internal/lambdalabs/CONTRIBUTE.md index 0cb530f..6366b1a 100644 --- a/internal/lambdalabs/CONTRIBUTE.md +++ 
b/internal/lambdalabs/CONTRIBUTE.md @@ -1,6 +1,19 @@ +# Contributing to Lambda Labs +## Setup -## Prompts +- Create a `.env` file in the root of the project with the following: + +``` +LAMBDALABS_API_KEY=secret_my-api-key_********** +``` + +## Running Tests + +Use the vscode "run tests" task to run the tests. + + +## Useful Prompts ``` can you take a look at the file structure in @v1 this is supposed to be the reference/interface for providers. can you replicate the file structure in @lambdalabs ? I just want the file structure and maybe some stubs if they make sense. ``` diff --git a/internal/lambdalabs/v1/capabilities.go b/internal/lambdalabs/v1/capabilities.go index ba348ab..99e1105 100644 --- a/internal/lambdalabs/v1/capabilities.go +++ b/internal/lambdalabs/v1/capabilities.go @@ -3,7 +3,7 @@ package v1 import ( "context" - v1 "github.com/brevdev/compute/pkg/v1" + v1 "github.com/brevdev/cloud/pkg/v1" ) // getLambdaLabsCapabilities returns the unified capabilities for Lambda Labs diff --git a/internal/lambdalabs/v1/capabilities_test.go b/internal/lambdalabs/v1/capabilities_test.go index a2bebb6..d76181f 100644 --- a/internal/lambdalabs/v1/capabilities_test.go +++ b/internal/lambdalabs/v1/capabilities_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - v1 "github.com/brevdev/compute/pkg/v1" + v1 "github.com/brevdev/cloud/pkg/v1" ) func TestLambdaLabsClient_GetCapabilities(t *testing.T) { diff --git a/internal/lambdalabs/v1/client.go b/internal/lambdalabs/v1/client.go index 7f10897..75ae00e 100644 --- a/internal/lambdalabs/v1/client.go +++ b/internal/lambdalabs/v1/client.go @@ -5,9 +5,11 @@ import ( "crypto/sha256" "fmt" "net/http" + "time" openapi "github.com/brevdev/cloud/internal/lambdalabs/gen/lambdalabs" - v1 "github.com/brevdev/compute/pkg/v1" + v1 "github.com/brevdev/cloud/pkg/v1" + "github.com/cenkalti/backoff/v4" ) // LambdaLabsCredential implements the CloudCredential interface for Lambda Labs @@ 
-36,9 +38,13 @@ func (c *LambdaLabsCredential) GetAPIType() v1.APIType { return v1.APITypeGlobal } +const CloudProviderID = "lambda-labs" + +const DefaultRegion string = "us-west-1" + // GetCloudProviderID returns the cloud provider ID for Lambda Labs func (c *LambdaLabsCredential) GetCloudProviderID() v1.CloudProviderID { - return "lambdalabs" + return CloudProviderID } // GetTenantID returns the tenant ID for Lambda Labs @@ -55,10 +61,11 @@ func (c *LambdaLabsCredential) MakeClient(_ context.Context, _ string) (v1.Cloud // It embeds NotImplCloudClient to handle unsupported features type LambdaLabsClient struct { v1.NotImplCloudClient - refID string - apiKey string - baseURL string - client *openapi.APIClient + refID string + apiKey string + baseURL string + client *openapi.APIClient + location string } var _ v1.CloudClient = &LambdaLabsClient{} @@ -88,8 +95,11 @@ func (c *LambdaLabsClient) GetCloudProviderID() v1.CloudProviderID { } // MakeClient creates a new client instance -func (c *LambdaLabsClient) MakeClient(_ context.Context, _ string) (v1.CloudClient, error) { - // Lambda Labs doesn't require location-specific clients +func (c *LambdaLabsClient) MakeClient(_ context.Context, location string) (v1.CloudClient, error) { + if location == "" { + location = DefaultRegion + } + c.location = location return c, nil } @@ -110,3 +120,10 @@ func (c *LambdaLabsClient) makeAuthContext(ctx context.Context) context.Context UserName: c.apiKey, }) } + +func getBackoff() backoff.BackOff { + bo := backoff.NewExponentialBackOff() + bo.InitialInterval = 1000 * time.Millisecond + bo.MaxElapsedTime = 120 * time.Second + return bo +} diff --git a/internal/lambdalabs/v1/client_test.go b/internal/lambdalabs/v1/client_test.go index 105232d..36a85ed 100644 --- a/internal/lambdalabs/v1/client_test.go +++ b/internal/lambdalabs/v1/client_test.go @@ -8,7 +8,7 @@ import ( "github.com/stretchr/testify/require" openapi "github.com/brevdev/cloud/internal/lambdalabs/gen/lambdalabs" - v1 
"github.com/brevdev/compute/pkg/v1" + v1 "github.com/brevdev/cloud/pkg/v1" ) func TestLambdaLabsClient_GetAPIType(t *testing.T) { diff --git a/internal/lambdalabs/v1/credential_test.go b/internal/lambdalabs/v1/credential_test.go index abdb5b2..700358d 100644 --- a/internal/lambdalabs/v1/credential_test.go +++ b/internal/lambdalabs/v1/credential_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - v1 "github.com/brevdev/compute/pkg/v1" + v1 "github.com/brevdev/cloud/pkg/v1" ) func TestLambdaLabsCredential_GetReferenceID(t *testing.T) { @@ -26,7 +26,7 @@ func TestLambdaLabsCredential_GetAPIType(t *testing.T) { func TestLambdaLabsCredential_GetCloudProviderID(t *testing.T) { cred := &LambdaLabsCredential{} - assert.Equal(t, v1.CloudProviderID("lambdalabs"), cred.GetCloudProviderID()) + assert.Equal(t, v1.CloudProviderID("lambda-labs"), cred.GetCloudProviderID()) } func TestLambdaLabsCredential_GetTenantID(t *testing.T) { diff --git a/internal/lambdalabs/v1/errors.go b/internal/lambdalabs/v1/errors.go new file mode 100644 index 0000000..10a47a6 --- /dev/null +++ b/internal/lambdalabs/v1/errors.go @@ -0,0 +1,51 @@ +package v1 + +import ( + "context" + "fmt" + "io" + "net/http" + "strings" + + openapi "github.com/brevdev/cloud/internal/lambdalabs/gen/lambdalabs" + v1 "github.com/brevdev/cloud/pkg/v1" + "github.com/cenkalti/backoff/v4" +) + +func handleAPIError(_ context.Context, resp *http.Response, err error) error { + body := "" + e, ok := err.(openapi.GenericOpenAPIError) + if ok { + body = string(e.Body()) + } + if body == "" { + bodyBytes, errr := io.ReadAll(resp.Body) + if errr != nil { + fmt.Printf("Error reading response body: %v\n", errr) + } + body = string(bodyBytes) + } + outErr := fmt.Errorf("LambdaLabs API error\n%s\n%s:\nErr: %s\n%s", resp.Request.URL, resp.Status, err.Error(), body) + if strings.Contains(body, "instance does not exist") { //nolint:gocritic // ignore + return 
backoff.Permanent(v1.ErrInstanceNotFound) + } else if strings.Contains(body, "banned you temporarily") { + return outErr + } else if resp.StatusCode < 500 && resp.StatusCode != 429 { // 429 Too Many Requests (use back off) + return backoff.Permanent(outErr) + } else { + return outErr + } +} + +func handleErrToCloudErr(e error) error { + if e == nil { + return nil + } + if strings.Contains(e.Error(), "Not enough capacity") || strings.Contains(e.Error(), "insufficient-capacity") { //nolint:gocritic // ignore + return v1.ErrInsufficientResources + } else if strings.Contains(e.Error(), "global/invalid-parameters") && strings.Contains(e.Error(), "Region") && strings.Contains(e.Error(), "does not exist") { + return v1.ErrInsufficientResources + } else { + return e + } +} diff --git a/internal/lambdalabs/v1/helpers_test.go b/internal/lambdalabs/v1/helpers_test.go index d660122..cb7f01f 100644 --- a/internal/lambdalabs/v1/helpers_test.go +++ b/internal/lambdalabs/v1/helpers_test.go @@ -5,7 +5,7 @@ import ( "github.com/stretchr/testify/assert" - v1 "github.com/brevdev/compute/pkg/v1" + v1 "github.com/brevdev/cloud/pkg/v1" ) func TestConvertLambdaLabsInstanceToV1Instance(t *testing.T) { diff --git a/internal/lambdalabs/v1/instance.go b/internal/lambdalabs/v1/instance.go index 9ecc05f..1e3efe8 100644 --- a/internal/lambdalabs/v1/instance.go +++ b/internal/lambdalabs/v1/instance.go @@ -2,12 +2,14 @@ package v1 import ( "context" + "errors" "fmt" "strings" "time" + "github.com/alecthomas/units" openapi "github.com/brevdev/cloud/internal/lambdalabs/gen/lambdalabs" - v1 "github.com/brevdev/compute/pkg/v1" + v1 "github.com/brevdev/cloud/pkg/v1" ) const lambdaLabsTimeNameFormat = "2006-01-02-15-04-05Z07-00" @@ -25,19 +27,22 @@ func (c *LambdaLabsClient) CreateInstance(ctx context.Context, attrs v1.CreateIn Name: keyPairName, PublicKey: &attrs.PublicKey, } - - _, resp, err := c.client.DefaultAPI.AddSSHKey(c.makeAuthContext(ctx)).AddSSHKeyRequest(request).Execute() + keyPairResp, 
resp, err := c.client.DefaultAPI.AddSSHKey(c.makeAuthContext(ctx)).AddSSHKeyRequest(request).Execute() if resp != nil { defer func() { _ = resp.Body.Close() }() } if err != nil && !strings.Contains(err.Error(), "name must be unique") { return nil, fmt.Errorf("failed to add SSH key: %w", err) } + keyPairName = keyPairResp.Data.Name + } + if keyPairName == "" { + return nil, errors.New("keyPairName is required if public key not provided") } location := attrs.Location if location == "" { - location = "us-west-1" + location = c.location } quantity := int32(1) @@ -50,9 +55,10 @@ func (c *LambdaLabsClient) CreateInstance(ctx context.Context, attrs v1.CreateIn } name := fmt.Sprintf("%s--%s", c.GetReferenceID(), time.Now().UTC().Format(lambdaLabsTimeNameFormat)) - if attrs.Name != "" { - name = fmt.Sprintf("%s--%s--%s", c.GetReferenceID(), attrs.Name, time.Now().UTC().Format(lambdaLabsTimeNameFormat)) + if len(name) > 64 { + return nil, errors.New("name is too long") } + request.Name = *openapi.NewNullableString(&name) resp, httpResp, err := c.client.DefaultAPI.LaunchInstance(c.makeAuthContext(ctx)).LaunchInstanceRequest(request).Execute() @@ -60,7 +66,7 @@ func (c *LambdaLabsClient) CreateInstance(ctx context.Context, attrs v1.CreateIn defer func() { _ = httpResp.Body.Close() }() } if err != nil { - return nil, fmt.Errorf("failed to launch instance: %w", err) + return nil, fmt.Errorf("failed to launch instance: %w", handleErrToCloudErr(err)) } if len(resp.Data.InstanceIds) != 1 { @@ -142,58 +148,73 @@ func (c *LambdaLabsClient) RebootInstance(ctx context.Context, instanceID v1.Clo } // MergeInstanceForUpdate merges instance data for updates -func convertLambdaLabsInstanceToV1Instance(llInstance openapi.Instance) *v1.Instance { - var publicIP, privateIP, hostname, name string - - if llInstance.Ip.IsSet() { - publicIP = *llInstance.Ip.Get() +func convertLambdaLabsInstanceToV1Instance(instance openapi.Instance) *v1.Instance { + var instanceIP string + if instance.Ip.IsSet() 
{ + instanceIP = *instance.Ip.Get() } - if llInstance.PrivateIp.IsSet() { - privateIP = *llInstance.PrivateIp.Get() - } - if llInstance.Hostname.IsSet() { - hostname = *llInstance.Hostname.Get() + + var instanceName string + if instance.Name.IsSet() { + instanceName = *instance.Name.Get() } - if llInstance.Name.IsSet() { - name = *llInstance.Name.Get() + + var instanceHostname string + if instance.Hostname.IsSet() { + instanceHostname = *instance.Hostname.Get() } + nameSplit := strings.Split(instanceName, "--") var cloudCredRefID string - var createdAt time.Time - if name != "" { - parts := strings.Split(name, "--") - if len(parts) > 0 { - cloudCredRefID = parts[0] - } - if len(parts) > 1 { - createdAt, _ = time.Parse("2006-01-02-15-04-05Z07-00", parts[1]) - } + if len(nameSplit) > 0 { + cloudCredRefID = nameSplit[0] + } + var createTime time.Time + if len(nameSplit) > 1 { + createTimeStr := nameSplit[1] + createTime, _ = time.Parse(lambdaLabsTimeNameFormat, createTimeStr) } - refID := "" - if len(llInstance.SshKeyNames) > 0 { - refID = llInstance.SshKeyNames[0] + var instancePrivateIP string + if instance.PrivateIp.IsSet() { + instancePrivateIP = *instance.PrivateIp.Get() } - return &v1.Instance{ - Name: name, - RefID: refID, + inst := v1.Instance{ + RefID: instance.SshKeyNames[0], CloudCredRefID: cloudCredRefID, - CreatedAt: createdAt, - CloudID: v1.CloudProviderInstanceID(llInstance.Id), - PublicIP: publicIP, - PrivateIP: privateIP, - PublicDNS: publicIP, - Hostname: hostname, - InstanceType: llInstance.InstanceType.Name, + CreatedAt: createTime, + CloudID: v1.CloudProviderInstanceID(instance.Id), + Name: instanceName, + PublicIP: instanceIP, + PublicDNS: instanceIP, + PrivateIP: instancePrivateIP, + Hostname: instanceHostname, Status: v1.Status{ - LifecycleStatus: convertLambdaLabsStatusToV1Status(llInstance.Status), + LifecycleStatus: convertLambdaLabsStatusToV1Status(instance.Status), + }, + InstanceType: instance.InstanceType.Name, + VolumeType: "ssd", + 
DiskSize: units.GiB * units.Base2Bytes(instance.InstanceType.Specs.StorageGib), + FirewallRules: v1.FirewallRules{ + IngressRules: []v1.FirewallRule{generateFirewallRouteFromPort(22), generateFirewallRouteFromPort(2222)}, // TODO pull from api + EgressRules: []v1.FirewallRule{generateFirewallRouteFromPort(22), generateFirewallRouteFromPort(2222)}, // TODO pull from api }, - Location: llInstance.Region.Name, SSHUser: "ubuntu", SSHPort: 22, Stoppable: false, Rebootable: true, + Location: instance.Region.Name, + } + inst.InstanceTypeID = v1.MakeGenericInstanceTypeIDFromInstance(inst) + return &inst +} + +func generateFirewallRouteFromPort(port int32) v1.FirewallRule { + return v1.FirewallRule{ + FromPort: port, + ToPort: port, + IPRanges: []string{"0.0.0.0/0"}, } } diff --git a/internal/lambdalabs/v1/instance_test.go b/internal/lambdalabs/v1/instance_test.go index aef19e7..87d388b 100644 --- a/internal/lambdalabs/v1/instance_test.go +++ b/internal/lambdalabs/v1/instance_test.go @@ -9,8 +9,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/brevdev/cloud/internal/collections" openapi "github.com/brevdev/cloud/internal/lambdalabs/gen/lambdalabs" - v1 "github.com/brevdev/compute/pkg/v1" + v1 "github.com/brevdev/cloud/pkg/v1" ) func TestLambdaLabsClient_CreateInstance_Success(t *testing.T) { @@ -79,6 +80,7 @@ func TestLambdaLabsClient_CreateInstance_WithoutPublicKey(t *testing.T) { InstanceType: "gpu_1x_a10", Location: "us-west-1", Name: "test-instance", + KeyPairName: collections.Ptr("test-key-pair"), } instance, err := client.CreateInstance(context.Background(), args) diff --git a/internal/lambdalabs/v1/instancetype.go b/internal/lambdalabs/v1/instancetype.go index 09f788d..4493f36 100644 --- a/internal/lambdalabs/v1/instancetype.go +++ b/internal/lambdalabs/v1/instancetype.go @@ -2,7 +2,6 @@ package v1 import ( "context" - "encoding/json" "fmt" "regexp" "strconv" @@ -11,181 +10,194 @@ import ( 
"github.com/alecthomas/units" "github.com/bojanz/currency" + "github.com/brevdev/cloud/internal/collections" openapi "github.com/brevdev/cloud/internal/lambdalabs/gen/lambdalabs" - v1 "github.com/brevdev/compute/pkg/v1" + v1 "github.com/brevdev/cloud/pkg/v1" ) -// GetInstanceTypes retrieves available instance types from Lambda Labs -// Supported via: GET /api/v1/instance-types +// GetInstanceTypePollTime returns the polling interval for instance types +func (c *LambdaLabsClient) GetInstanceTypePollTime() time.Duration { + return 5 * time.Minute +} + func (c *LambdaLabsClient) GetInstanceTypes(ctx context.Context, args v1.GetInstanceTypeArgs) ([]v1.InstanceType, error) { - resp, httpResp, err := c.client.DefaultAPI.InstanceTypes(c.makeAuthContext(ctx)).Execute() - if httpResp != nil { - defer func() { _ = httpResp.Body.Close() }() + instanceTypesResp, err := c.getInstanceTypes(ctx) + if err != nil { + return nil, err } + + locations, err := c.GetLocations(ctx, v1.GetLocationsArgs{}) if err != nil { - return nil, fmt.Errorf("failed to get instance types: %w", err) - } - - var instanceTypes []v1.InstanceType - for _, llInstanceTypeData := range resp.Data { - for _, region := range llInstanceTypeData.RegionsWithCapacityAvailable { - instanceType, err := convertLambdaLabsInstanceTypeToV1InstanceType( - region.Name, - llInstanceTypeData.InstanceType, - true, - ) - if err != nil { - return nil, fmt.Errorf("failed to convert instance type: %w", err) + return nil, err + } + + instanceTypes, err := collections.MapE(collections.GetMapValues(instanceTypesResp.Data), func(resp openapi.InstanceTypes200ResponseDataValue) ([]v1.InstanceType, error) { + currentlyAvailableRegions := collections.GroupBy(resp.RegionsWithCapacityAvailable, func(lambdaRegion openapi.Region) string { + return lambdaRegion.Name + }) + its, err1 := collections.MapE(locations, func(region v1.Location) (v1.InstanceType, error) { + isAvailable := false + if _, ok := currentlyAvailableRegions[region.Name]; ok 
{ + isAvailable = true + } + it, err2 := convertLambdaLabsInstanceTypeToV1InstanceType(region.Name, resp.InstanceType, isAvailable) + if err2 != nil { + return v1.InstanceType{}, err2 } - instanceTypes = append(instanceTypes, instanceType) + return it, nil + }) + if err1 != nil { + return []v1.InstanceType{}, err1 } + return its, nil + }) + if err != nil { + return nil, err } + instanceTypesFlattened := collections.Flatten(instanceTypes) - if len(args.Locations) > 0 && !args.Locations.IsAll() { - filtered := make([]v1.InstanceType, 0) - for _, it := range instanceTypes { - for _, loc := range args.Locations { - if it.Location == loc { - filtered = append(filtered, it) - break - } - } + if len(args.Locations) == 0 { + if c.location != "" { + args.Locations = []string{c.location} + } else { + args.Locations = v1.All } - instanceTypes = filtered } - if len(args.InstanceTypes) > 0 { - filtered := make([]v1.InstanceType, 0) - for _, it := range instanceTypes { - for _, itName := range args.InstanceTypes { - if it.Type == itName { - filtered = append(filtered, it) - break + if !args.Locations.IsAll() { + instanceTypesFlattened = collections.Filter(instanceTypesFlattened, func(it v1.InstanceType) bool { + return collections.ListContains(args.Locations, it.Location) + }) + } + + if len(args.SupportedArchitectures) > 0 { + instanceTypesFlattened = collections.Filter(instanceTypesFlattened, func(instanceType v1.InstanceType) bool { + for _, arch := range args.SupportedArchitectures { + if collections.ListContains(instanceType.SupportedArchitectures, arch) { + return true } } - } - instanceTypes = filtered + return false + }) + } + if len(args.InstanceTypes) > 0 { + instanceTypesFlattened = collections.Filter(instanceTypesFlattened, func(instanceType v1.InstanceType) bool { + return collections.ListContains(args.InstanceTypes, instanceType.Type) + }) } - return instanceTypes, nil + return instanceTypesFlattened, nil } -// GetInstanceTypePollTime returns the polling interval 
for instance types -func (c *LambdaLabsClient) GetInstanceTypePollTime() time.Duration { - // TODO: Configure appropriate polling time for Lambda Labs - return 5 * time.Minute -} - -// GetLocations retrieves available locations from Lambda Labs -// UNSUPPORTED: No location listing endpoints found in Lambda Labs API -func convertLambdaLabsInstanceTypeToV1InstanceType(location string, llInstanceType openapi.InstanceType, isAvailable bool) (v1.InstanceType, error) { - var gpus []v1.GPU - if !strings.Contains(llInstanceType.Description, "CPU") { - gpu := parseGPUFromDescription(llInstanceType.Description) - gpus = append(gpus, gpu) - } - - amount, err := currency.NewAmountFromInt64(int64(llInstanceType.PriceCentsPerHour), "USD") +func (c *LambdaLabsClient) getInstanceTypes(ctx context.Context) (*openapi.InstanceTypes200Response, error) { + ilr, err := collections.RetryWithDataAndAttemptCount(func() (*openapi.InstanceTypes200Response, error) { + res, resp, err := c.client.DefaultAPI.InstanceTypes(c.makeAuthContext(ctx)).Execute() + if resp != nil { + defer resp.Body.Close() //nolint:errcheck // ignore because using defer (for some reason HandleErrDefer) + } + if err != nil { + return &openapi.InstanceTypes200Response{}, handleAPIError(ctx, resp, err) + } + return res, nil + }, getBackoff()) if err != nil { - return v1.InstanceType{}, fmt.Errorf("failed to create price amount: %w", err) + return nil, err } - instanceType := v1.InstanceType{ - Location: location, - Type: llInstanceType.Name, - SupportedGPUs: gpus, - SupportedStorage: []v1.Storage{ - { - Type: "ssd", - Size: units.GiB * units.Base2Bytes(llInstanceType.Specs.StorageGib), - }, - }, - Memory: units.GiB * units.Base2Bytes(llInstanceType.Specs.MemoryGib), - VCPU: llInstanceType.Specs.Vcpus, - SupportedArchitectures: []string{"x86_64"}, - Stoppable: false, - Rebootable: true, - IsAvailable: isAvailable, - BasePrice: &amount, - Provider: "lambdalabs", - } - - instanceType.ID = 
v1.InstanceTypeID(fmt.Sprintf("lambdalabs-%s-%s", location, llInstanceType.Name)) - - return instanceType, nil + return ilr, nil } -func parseGPUFromDescription(description string) v1.GPU { - countRegex := regexp.MustCompile(`(\d+)x`) - memoryRegex := regexp.MustCompile(`(\d+) GB`) - nameRegex := regexp.MustCompile(`x (.*?) \(`) - +func parseGPUFromDescription(input string) (v1.GPU, error) { var gpu v1.GPU - if matches := countRegex.FindStringSubmatch(description); len(matches) > 1 { - if count, err := strconv.ParseInt(matches[1], 10, 32); err == nil { - gpu.Count = int32(count) - } + // Extract the count + countRegex := regexp.MustCompile(`(\d+)x`) + countMatch := countRegex.FindStringSubmatch(input) + if len(countMatch) == 0 { + return v1.GPU{}, fmt.Errorf("could not find count in %s", input) } + count, _ := strconv.ParseInt(countMatch[1], 10, 32) + gpu.Count = int32(count) - if matches := memoryRegex.FindStringSubmatch(description); len(matches) > 1 { - if memory, err := strconv.Atoi(matches[1]); err == nil { - gpu.Memory = units.GiB * units.Base2Bytes(memory) - } + // Extract the memory + memoryRegex := regexp.MustCompile(`(\d+) GB`) + memoryMatch := memoryRegex.FindStringSubmatch(input) + if len(memoryMatch) == 0 { + return v1.GPU{}, fmt.Errorf("could not find memory in %s", input) + } + memoryStr := memoryMatch[1] + memoryGiB, _ := strconv.Atoi(memoryStr) + gpu.Memory = units.GiB * units.Base2Bytes(memoryGiB) + + // Extract the network details + networkRegex := regexp.MustCompile(`(\w+\s?)+\)`) + networkMatch := networkRegex.FindStringSubmatch(input) + if len(networkMatch) == 0 { + return v1.GPU{}, fmt.Errorf("could not find network details in %s", input) } + networkStr := strings.TrimSuffix(networkMatch[0], ")") + networkDetails := strings.TrimSpace(strings.ReplaceAll(networkStr, memoryStr+" GB", "")) + gpu.NetworkDetails = networkDetails - if matches := nameRegex.FindStringSubmatch(description); len(matches) > 1 { - gpu.Name = strings.TrimSpace(matches[1]) 
- gpu.Type = gpu.Name + // Extract the name + nameRegex := regexp.MustCompile(`x (.*?) \(`) + nameMatch := nameRegex.FindStringSubmatch(input) + if len(nameMatch) == 0 { + return v1.GPU{}, fmt.Errorf("could not find name in %s", input) + } + nameStr := strings.TrimRight(strings.TrimLeft(nameMatch[0], "x "), " (") + nameStr = regexp.MustCompile(`(?i)^Tesla\s+`).ReplaceAllString(nameStr, "") + gpu.Name = nameStr + if networkDetails != "" { + gpu.Type = nameStr + "." + networkDetails + } else { + gpu.Type = nameStr } gpu.Manufacturer = "NVIDIA" - return gpu + return gpu, nil } -const lambdaLocationsData = `[ - {"location_name": "us-west-1", "description": "California, USA", "country": "USA"}, - {"location_name": "us-west-2", "description": "Arizona, USA", "country": "USA"}, - {"location_name": "us-west-3", "description": "Utah, USA", "country": "USA"}, - {"location_name": "us-south-1", "description": "Texas, USA", "country": "USA"}, - {"location_name": "us-east-1", "description": "Virginia, USA", "country": "USA"}, - {"location_name": "us-midwest-1", "description": "Illinois, USA", "country": "USA"}, - {"location_name": "australia-southeast-1", "description": "Australia", "country": "AUS"}, - {"location_name": "europe-central-1", "description": "Germany", "country": "DEU"}, - {"location_name": "asia-south-1", "description": "India", "country": "IND"}, - {"location_name": "me-west-1", "description": "Israel", "country": "ISR"}, - {"location_name": "europe-south-1", "description": "Italy", "country": "ITA"}, - {"location_name": "asia-northeast-1", "description": "Osaka, Japan", "country": "JPN"}, - {"location_name": "asia-northeast-2", "description": "Tokyo, Japan", "country": "JPN"}, - {"location_name": "us-east-3", "description": "Washington D.C, USA", "country": "USA"}, - {"location_name": "us-east-2", "description": "Washington D.C, USA", "country": "USA"}, - {"location_name": "australia-east-1", "description": "Sydney, Australia", "country": "AUS"}, - 
{"location_name": "us-south-3", "description": "Central Texas, USA", "country": "USA"}, - {"location_name": "us-south-2", "description": "North Texas, USA", "country": "USA"} -]` - -type LambdaLocation struct { - LocationName string `json:"location_name"` - Description string `json:"description"` - Country string `json:"country"` -} - -func (c *LambdaLabsClient) GetLocations(_ context.Context, _ v1.GetLocationsArgs) ([]v1.Location, error) { - var regionData []LambdaLocation - if err := json.Unmarshal([]byte(lambdaLocationsData), ®ionData); err != nil { - return nil, fmt.Errorf("failed to parse location data: %w", err) +func convertLambdaLabsInstanceTypeToV1InstanceType(location string, instType openapi.InstanceType, isAvailable bool) (v1.InstanceType, error) { + gpus := []v1.GPU{} + if !strings.Contains(instType.Description, "CPU") { + gpu, err := parseGPUFromDescription(instType.Description) + if err != nil { + return v1.InstanceType{}, err + } + gpus = append(gpus, gpu) } - - locations := make([]v1.Location, 0, len(regionData)) - for _, region := range regionData { - locations = append(locations, v1.Location{ - Name: region.LocationName, - Description: region.Description, - Available: true, - Country: region.Country, - }) + amount, err := currency.NewAmountFromInt64(int64(instType.PriceCentsPerHour), "USD") + if err != nil { + return v1.InstanceType{}, err } - - return locations, nil + it := v1.InstanceType{ + Location: location, + Type: instType.Name, + SupportedGPUs: gpus, + SupportedStorage: []v1.Storage{ + { + Type: "ssd", + Count: 1, + Size: units.GiB * units.Base2Bytes(instType.Specs.StorageGib), + }, + }, + SupportedUsageClasses: []string{"on-demand"}, + Memory: units.GiB * units.Base2Bytes(instType.Specs.MemoryGib), + MaximumNetworkInterfaces: 0, + NetworkPerformance: "", + SupportedNumCores: []int32{}, + DefaultCores: 0, + VCPU: instType.Specs.Vcpus, + SupportedArchitectures: []string{"x86_64"}, + ClockSpeedInGhz: 0, + Stoppable: false, + Rebootable: 
true, + IsAvailable: isAvailable, + BasePrice: &amount, + Provider: string(CloudProviderID), + } + it.ID = v1.MakeGenericInstanceTypeID(it) + return it, nil } diff --git a/internal/lambdalabs/v1/instancetype_test.go b/internal/lambdalabs/v1/instancetype_test.go index 13fc98f..483eb07 100644 --- a/internal/lambdalabs/v1/instancetype_test.go +++ b/internal/lambdalabs/v1/instancetype_test.go @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/require" openapi "github.com/brevdev/cloud/internal/lambdalabs/gen/lambdalabs" - v1 "github.com/brevdev/compute/pkg/v1" + v1 "github.com/brevdev/cloud/pkg/v1" ) func TestLambdaLabsClient_GetInstanceTypes_Success(t *testing.T) { @@ -24,7 +24,9 @@ func TestLambdaLabsClient_GetInstanceTypes_Success(t *testing.T) { instanceTypes, err := client.GetInstanceTypes(context.Background(), v1.GetInstanceTypeArgs{}) require.NoError(t, err) - assert.Len(t, instanceTypes, 3) + locations, err := getLambdaLabsLocations() + require.NoError(t, err) + assert.Len(t, instanceTypes, len(locations)*2) a10Type := findInstanceTypeByName(instanceTypes, "gpu_1x_a10") require.NotNil(t, a10Type) @@ -33,7 +35,7 @@ func TestLambdaLabsClient_GetInstanceTypes_Success(t *testing.T) { assert.Len(t, a10Type.SupportedGPUs, 1) assert.Equal(t, int32(1), a10Type.SupportedGPUs[0].Count) assert.Equal(t, "NVIDIA", a10Type.SupportedGPUs[0].Manufacturer) - assert.Equal(t, "NVIDIA A10", a10Type.SupportedGPUs[0].Name) + assert.Equal(t, "A10", a10Type.SupportedGPUs[0].Name) } func TestLambdaLabsClient_GetInstanceTypes_FilterByLocation(t *testing.T) { @@ -48,7 +50,7 @@ func TestLambdaLabsClient_GetInstanceTypes_FilterByLocation(t *testing.T) { Locations: v1.LocationsFilter{"us-west-1"}, }) require.NoError(t, err) - assert.Len(t, instanceTypes, 1) + assert.Len(t, instanceTypes, 2) for _, instanceType := range instanceTypes { assert.Equal(t, "us-west-1", instanceType.Location) @@ -67,7 +69,9 @@ func TestLambdaLabsClient_GetInstanceTypes_FilterByInstanceType(t *testing.T) { 
InstanceTypes: []string{"gpu_1x_a10"}, }) require.NoError(t, err) - assert.Len(t, instanceTypes, 2) + locations, err := getLambdaLabsLocations() + require.NoError(t, err) + assert.Len(t, instanceTypes, len(locations)) for _, instanceType := range instanceTypes { assert.Equal(t, "gpu_1x_a10", instanceType.Type) @@ -165,40 +169,34 @@ func TestParseGPUFromDescription(t *testing.T) { expected v1.GPU }{ { - description: "1x NVIDIA A10 (24 GB)", - expected: v1.GPU{ - Count: 1, - Manufacturer: "NVIDIA", - Name: "NVIDIA A10", - Type: "NVIDIA A10", - Memory: 24 * 1024 * 1024 * 1024, - }, - }, - { - description: "8x NVIDIA H100 (80 GB)", + description: "1x H100 (80 GB SXM5)", expected: v1.GPU{ - Count: 8, - Manufacturer: "NVIDIA", - Name: "NVIDIA H100", - Type: "NVIDIA H100", - Memory: 80 * 1024 * 1024 * 1024, + Count: 1, + Manufacturer: "NVIDIA", + Name: "H100", + Type: "H100.SXM5", + Memory: 80 * 1024 * 1024 * 1024, + NetworkDetails: "80 GB SXM5", + MemoryDetails: "80 GB", }, }, { - description: "4x NVIDIA RTX 4090 (24 GB)", + description: "8x Tesla V100 (16 GB)", expected: v1.GPU{ - Count: 4, - Manufacturer: "NVIDIA", - Name: "NVIDIA RTX 4090", - Type: "NVIDIA RTX 4090", - Memory: 24 * 1024 * 1024 * 1024, + Count: 8, + Manufacturer: "NVIDIA", + Name: "V100", + Type: "V100", + Memory: 16 * 1024 * 1024 * 1024, + MemoryDetails: "16 GB", }, }, } for _, tt := range tests { t.Run(tt.description, func(t *testing.T) { - gpu := parseGPUFromDescription(tt.description) + gpu, err := parseGPUFromDescription(tt.description) + require.NoError(t, err) assert.Equal(t, tt.expected.Count, gpu.Count) assert.Equal(t, tt.expected.Manufacturer, gpu.Manufacturer) assert.Equal(t, tt.expected.Name, gpu.Name) @@ -212,14 +210,14 @@ func createMockInstanceTypeResponse() openapi.InstanceTypes200Response { return openapi.InstanceTypes200Response{ Data: map[string]openapi.InstanceTypes200ResponseDataValue{ "gpu_1x_a10": { - InstanceType: createMockLambdaLabsInstanceType("gpu_1x_a10", "1x NVIDIA A10 (24 
GB)", "NVIDIA A10", 100), + InstanceType: createMockLambdaLabsInstanceType("gpu_1x_a10", "1x A10 (24 GB)", "A10", 100), RegionsWithCapacityAvailable: []openapi.Region{ createMockRegion("us-west-1", "US West 1"), createMockRegion("us-east-1", "US East 1"), }, }, "gpu_8x_h100": { - InstanceType: createMockLambdaLabsInstanceType("gpu_8x_h100", "8x NVIDIA H100 (80 GB)", "NVIDIA H100", 3200), + InstanceType: createMockLambdaLabsInstanceType("gpu_8x_h100", "8x H100 (80 GB SXM5)", "H100", 3200), RegionsWithCapacityAvailable: []openapi.Region{ createMockRegion("us-east-1", "US East 1"), }, diff --git a/internal/lambdalabs/v1/location.go b/internal/lambdalabs/v1/location.go new file mode 100644 index 0000000..2283407 --- /dev/null +++ b/internal/lambdalabs/v1/location.go @@ -0,0 +1,63 @@ +package v1 + +import ( + "context" + "encoding/json" + "fmt" + + v1 "github.com/brevdev/cloud/pkg/v1" +) + +const lambdaLocationsData = `[ + {"location_name": "us-west-1", "description": "California, USA", "country": "USA"}, + {"location_name": "us-west-2", "description": "Arizona, USA", "country": "USA"}, + {"location_name": "us-west-3", "description": "Utah, USA", "country": "USA"}, + {"location_name": "us-south-1", "description": "Texas, USA", "country": "USA"}, + {"location_name": "us-east-1", "description": "Virginia, USA", "country": "USA"}, + {"location_name": "us-midwest-1", "description": "Illinois, USA", "country": "USA"}, + {"location_name": "australia-southeast-1", "description": "Australia", "country": "AUS"}, + {"location_name": "europe-central-1", "description": "Germany", "country": "DEU"}, + {"location_name": "asia-south-1", "description": "India", "country": "IND"}, + {"location_name": "me-west-1", "description": "Israel", "country": "ISR"}, + {"location_name": "europe-south-1", "description": "Italy", "country": "ITA"}, + {"location_name": "asia-northeast-1", "description": "Osaka, Japan", "country": "JPN"}, + {"location_name": "asia-northeast-2", "description": "Tokyo, 
Japan", "country": "JPN"}, + {"location_name": "us-east-3", "description": "Washington D.C, USA", "country": "USA"}, + {"location_name": "us-east-2", "description": "Washington D.C, USA", "country": "USA"}, + {"location_name": "australia-east-1", "description": "Sydney, Australia", "country": "AUS"}, + {"location_name": "us-south-3", "description": "Central Texas, USA", "country": "USA"}, + {"location_name": "us-south-2", "description": "North Texas, USA", "country": "USA"} +]` + +type LambdaLocation struct { + LocationName string `json:"location_name"` + Description string `json:"description"` + Country string `json:"country"` +} + +func getLambdaLabsLocations() ([]LambdaLocation, error) { + var locationData []LambdaLocation + if err := json.Unmarshal([]byte(lambdaLocationsData), &locationData); err != nil { + return nil, fmt.Errorf("failed to parse location data: %w", err) + } + return locationData, nil +} + +func (c *LambdaLabsClient) GetLocations(_ context.Context, _ v1.GetLocationsArgs) ([]v1.Location, error) { + locationData, err := getLambdaLabsLocations() + if err != nil { + return nil, err + } + + locations := make([]v1.Location, 0, len(locationData)) + for _, location := range locationData { + locations = append(locations, v1.Location{ + Name: location.LocationName, + Description: location.Description, + Available: true, + Country: location.Country, + }) + } + + return locations, nil +} diff --git a/internal/lambdalabs/v1/location_test.go b/internal/lambdalabs/v1/location_test.go new file mode 100644 index 0000000..b7b1f99 --- /dev/null +++ b/internal/lambdalabs/v1/location_test.go @@ -0,0 +1 @@ +package v1 diff --git a/internal/lambdalabs/v1/validation_test.go b/internal/lambdalabs/v1/validation_test.go new file mode 100644 index 0000000..54b8b5b --- /dev/null +++ b/internal/lambdalabs/v1/validation_test.go @@ -0,0 +1,46 @@ +package v1 + +import ( + "os" + "testing" + + "github.com/brevdev/cloud/internal/validation" + v1 "github.com/brevdev/cloud/pkg/v1" 
+) + +func TestValidationFunctions(t *testing.T) { + checkSkip(t) + apiKey := getAPIKey() + + config := validation.ProviderConfig{ + Credential: NewLambdaLabsCredential("validation-test", apiKey), + StableIDs: []v1.InstanceTypeID{"us-west-1-noSub-gpu_8x_a100_80gb_sxm4", "us-east-1-noSub-gpu_8x_a100_80gb_sxm4"}, + } + + validation.RunValidationSuite(t, config) +} + +func TestInstanceLifecycleValidation(t *testing.T) { + checkSkip(t) + apiKey := getAPIKey() + + config := validation.ProviderConfig{ + Credential: NewLambdaLabsCredential("validation-test", apiKey), + } + + validation.RunInstanceLifecycleValidation(t, config) +} + +func checkSkip(t *testing.T) { + apiKey := getAPIKey() + isValidationTest := os.Getenv("VALIDATION_TEST") + if apiKey == "" && isValidationTest != "" { + t.Fatal("LAMBDALABS_API_KEY not set, but VALIDATION_TEST is set") + } else if apiKey == "" && isValidationTest == "" { + t.Skip("LAMBDALABS_API_KEY not set, skipping LambdaLabs validation tests") + } +} + +func getAPIKey() string { + return os.Getenv("LAMBDALABS_API_KEY") +} diff --git a/internal/nebius/v1/capabilities.go b/internal/nebius/v1/capabilities.go index 8d449c9..da76b41 100644 --- a/internal/nebius/v1/capabilities.go +++ b/internal/nebius/v1/capabilities.go @@ -3,7 +3,7 @@ package v1 import ( "context" - v1 "github.com/brevdev/compute/pkg/v1" + v1 "github.com/brevdev/cloud/pkg/v1" ) func getNebiusCapabilities() v1.Capabilities { diff --git a/internal/nebius/v1/client.go b/internal/nebius/v1/client.go index 9166671..25d6d20 100644 --- a/internal/nebius/v1/client.go +++ b/internal/nebius/v1/client.go @@ -4,7 +4,7 @@ import ( "context" "fmt" - v1 "github.com/brevdev/compute/pkg/v1" + v1 "github.com/brevdev/cloud/pkg/v1" "github.com/nebius/gosdk" ) diff --git a/internal/nebius/v1/image.go b/internal/nebius/v1/image.go index 3d46386..f3540bf 100644 --- a/internal/nebius/v1/image.go +++ b/internal/nebius/v1/image.go @@ -3,7 +3,7 @@ package v1 import ( "context" - v1 
"github.com/brevdev/compute/pkg/v1" + v1 "github.com/brevdev/cloud/pkg/v1" ) func (c *NebiusClient) GetImages(_ context.Context, _ v1.GetImageArgs) ([]v1.Image, error) { diff --git a/internal/nebius/v1/instance.go b/internal/nebius/v1/instance.go index d1a4688..7fc2824 100644 --- a/internal/nebius/v1/instance.go +++ b/internal/nebius/v1/instance.go @@ -3,7 +3,7 @@ package v1 import ( "context" - v1 "github.com/brevdev/compute/pkg/v1" + v1 "github.com/brevdev/cloud/pkg/v1" ) func (c *NebiusClient) CreateInstance(_ context.Context, _ v1.CreateInstanceAttrs) (*v1.Instance, error) { diff --git a/internal/nebius/v1/instancetype.go b/internal/nebius/v1/instancetype.go index 509b670..80b6e83 100644 --- a/internal/nebius/v1/instancetype.go +++ b/internal/nebius/v1/instancetype.go @@ -4,7 +4,7 @@ import ( "context" "time" - v1 "github.com/brevdev/compute/pkg/v1" + v1 "github.com/brevdev/cloud/pkg/v1" ) func (c *NebiusClient) GetInstanceTypes(_ context.Context, _ v1.GetInstanceTypeArgs) ([]v1.InstanceType, error) { diff --git a/internal/nebius/v1/location.go b/internal/nebius/v1/location.go index 8a1c17b..3491e0b 100644 --- a/internal/nebius/v1/location.go +++ b/internal/nebius/v1/location.go @@ -3,7 +3,7 @@ package v1 import ( "context" - v1 "github.com/brevdev/compute/pkg/v1" + v1 "github.com/brevdev/cloud/pkg/v1" ) func (c *NebiusClient) GetLocations(_ context.Context, _ v1.GetLocationsArgs) ([]v1.Location, error) { diff --git a/internal/nebius/v1/networking.go b/internal/nebius/v1/networking.go index f912b74..d31c39f 100644 --- a/internal/nebius/v1/networking.go +++ b/internal/nebius/v1/networking.go @@ -3,7 +3,7 @@ package v1 import ( "context" - v1 "github.com/brevdev/compute/pkg/v1" + v1 "github.com/brevdev/cloud/pkg/v1" ) func (c *NebiusClient) AddFirewallRulesToInstance(_ context.Context, _ v1.AddFirewallRulesToInstanceArgs) error { diff --git a/internal/nebius/v1/quota.go b/internal/nebius/v1/quota.go index 9920f67..40601ed 100644 --- a/internal/nebius/v1/quota.go 
+++ b/internal/nebius/v1/quota.go @@ -3,7 +3,7 @@ package v1 import ( "context" - v1 "github.com/brevdev/compute/pkg/v1" + v1 "github.com/brevdev/cloud/pkg/v1" ) func (c *NebiusClient) GetInstanceTypeQuotas(_ context.Context, _ v1.GetInstanceTypeQuotasArgs) (v1.Quota, error) { diff --git a/internal/nebius/v1/storage.go b/internal/nebius/v1/storage.go index 3d71df8..b642e18 100644 --- a/internal/nebius/v1/storage.go +++ b/internal/nebius/v1/storage.go @@ -3,7 +3,7 @@ package v1 import ( "context" - v1 "github.com/brevdev/compute/pkg/v1" + v1 "github.com/brevdev/cloud/pkg/v1" ) func (c *NebiusClient) ResizeInstanceVolume(_ context.Context, _ v1.ResizeInstanceVolumeArgs) error { diff --git a/internal/nebius/v1/tags.go b/internal/nebius/v1/tags.go index e186b1b..d79bae1 100644 --- a/internal/nebius/v1/tags.go +++ b/internal/nebius/v1/tags.go @@ -3,7 +3,7 @@ package v1 import ( "context" - v1 "github.com/brevdev/compute/pkg/v1" + v1 "github.com/brevdev/cloud/pkg/v1" ) func (c *NebiusClient) UpdateInstanceTags(_ context.Context, _ v1.UpdateInstanceTagsArgs) error { diff --git a/internal/validation/suite.go b/internal/validation/suite.go new file mode 100644 index 0000000..ee9b9da --- /dev/null +++ b/internal/validation/suite.go @@ -0,0 +1,128 @@ +package validation + +import ( + "context" + "testing" + "time" + + "github.com/brevdev/cloud/pkg/ssh" + v1 "github.com/brevdev/cloud/pkg/v1" + "github.com/stretchr/testify/require" +) + +type ProviderConfig struct { + Location string + StableIDs []v1.InstanceTypeID + Credential v1.CloudCredential +} + +func RunValidationSuite(t *testing.T, config ProviderConfig) { + if testing.Short() { + t.Skip("Skipping validation tests in short mode") + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + + client, err := config.Credential.MakeClient(ctx, config.Location) + if err != nil { + t.Fatalf("Failed to create client for %s: %v", config.Credential.GetCloudProviderID(), err) + } + + 
t.Run("ValidateGetLocations", func(t *testing.T) { + err := v1.ValidateGetLocations(ctx, client) + require.NoError(t, err, "ValidateGetLocations should pass") + }) + + t.Run("ValidateGetInstanceTypes", func(t *testing.T) { + err := v1.ValidateGetInstanceTypes(ctx, client) + require.NoError(t, err, "ValidateGetInstanceTypes should pass") + }) + + t.Run("ValidateRegionalInstanceTypes", func(t *testing.T) { + err := v1.ValidateLocationalInstanceTypes(ctx, client) + require.NoError(t, err, "ValidateRegionalInstanceTypes should pass") + }) + + t.Run("ValidateStableInstanceTypeIDs", func(t *testing.T) { + err = v1.ValidateStableInstanceTypeIDs(ctx, client, config.StableIDs) + require.NoError(t, err, "ValidateStableInstanceTypeIDs should pass") + }) +} + +func RunInstanceLifecycleValidation(t *testing.T, config ProviderConfig) { + if testing.Short() { + t.Skip("Skipping validation tests in short mode") + } + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) + defer cancel() + + client, err := config.Credential.MakeClient(ctx, config.Location) + if err != nil { + t.Fatalf("Failed to create client for %s: %v", config.Credential.GetCloudProviderID(), err) + } + capabilities, err := client.GetCapabilities(ctx) + require.NoError(t, err) + + types, err := client.GetInstanceTypes(ctx, v1.GetInstanceTypeArgs{}) + require.NoError(t, err) + require.NotEmpty(t, types, "Should have instance types") + + locations, err := client.GetLocations(ctx, v1.GetLocationsArgs{}) + require.NoError(t, err) + require.NotEmpty(t, locations, "Should have locations") + + t.Run("ValidateCreateInstance", func(t *testing.T) { + attrs := v1.CreateInstanceAttrs{} + for _, typ := range types { + if typ.IsAvailable { + attrs.InstanceType = typ.Type + attrs.Location = typ.Location + attrs.PublicKey = ssh.GetTestPublicKey() + break + } + } + instance, err := v1.ValidateCreateInstance(ctx, client, attrs) + if err != nil { + t.Fatalf("ValidateCreateInstance failed: %v", err) + } + 
require.NotNil(t, instance) + + defer func() { + if instance != nil { + _ = client.TerminateInstance(ctx, instance.CloudID) + } + }() + + t.Run("ValidateListCreatedInstance", func(t *testing.T) { + err := v1.ValidateListCreatedInstance(ctx, client, instance) + require.NoError(t, err, "ValidateListCreatedInstance should pass") + }) + + t.Run("ValidateSSHAccessible", func(t *testing.T) { + err := v1.ValidateInstanceSSHAccessible(ctx, client, instance, ssh.GetTestPrivateKey()) + require.NoError(t, err, "ValidateSSHAccessible should pass") + }) + + instance, err = client.GetInstance(ctx, instance.CloudID) + require.NoError(t, err) + + t.Run("ValidateInstanceImage", func(t *testing.T) { + err := v1.ValidateInstanceImage(ctx, *instance, ssh.GetTestPrivateKey()) + require.NoError(t, err, "ValidateInstanceImage should pass") + }) + + if capabilities.IsCapable(v1.CapabilityStopStartInstance) && instance.Stoppable { + t.Run("ValidateStopStartInstance", func(t *testing.T) { + err := v1.ValidateStopStartInstance(ctx, client, instance) + require.NoError(t, err, "ValidateStopStartInstance should pass") + }) + } + + t.Run("ValidateTerminateInstance", func(t *testing.T) { + err := v1.ValidateTerminateInstance(ctx, client, instance) + require.NoError(t, err, "ValidateTerminateInstance should pass") + }) + }) +} diff --git a/pkg/ssh/ssh.go b/pkg/ssh/ssh.go new file mode 100644 index 0000000..adba039 --- /dev/null +++ b/pkg/ssh/ssh.go @@ -0,0 +1,482 @@ +package ssh + +import ( + "bytes" + "context" + "encoding/base64" + "fmt" + "io" + "net" + "net/http" + "os" + "strings" + "sync" + "time" + + "github.com/cenkalti/backoff/v4" + "github.com/gliderlabs/ssh" + gossh "golang.org/x/crypto/ssh" +) + +func init() { + setPublicIP(context.Background()) + setLocalIP(context.Background()) +} + +func ConnectToHost(ctx context.Context, config ConnectionConfig) (*Client, error) { + d := net.Dialer{} + sshClient := &Client{ + addr: config.HostPort, + user: config.User, + dial: d.DialContext, + 
privateKey: config.PrivKey, + } + fmt.Printf("local_ip: %s, public_ip: %s\n", localIP, publicIP) + err := sshClient.Connect(ctx) + if err != nil { + return nil, fmt.Errorf("failed to connect: %w", err) + } + + return sshClient, nil +} + +// modified from https://github.com/superfly/flyctl/blob/master/ssh/client.go +// https://github.com/golang/go/issues/20288#issuecomment-832033017 +type Client struct { + addr string + user string + + dial func(ctx context.Context, network, addr string) (net.Conn, error) + + privateKey, Certificate string + + hostKeyAlgorithms []string + + client *gossh.Client + conn gossh.Conn +} + +func sshBackoff() backoff.BackOff { + b := backoff.NewExponentialBackOff() + b.MaxElapsedTime = 6 * time.Second + b.InitialInterval = 250 * time.Millisecond + b.MaxInterval = 3 * time.Second + return b +} + +func (c *Client) RunCommand(ctx context.Context, cmd string) (string, string, error) { + var stdout, stderr string + var err error + + err = backoff.Retry(func() error { + stdout, stderr, err = c.runCommand(ctx, cmd) + if err != nil && (err == io.EOF || strings.Contains(err.Error(), "unexpected packet in response to channel open: ")) { + cerr := c.Connect(ctx) + if cerr != nil { + return fmt.Errorf("connection error: %w, original error: %w", cerr, err) + } + return err + } else if err != nil { + return backoff.Permanent(err) + } + return nil + }, backoff.WithContext(sshBackoff(), ctx)) + + return stdout, stderr, err +} + +// returns stdout, stderr and error that may be an ssh ExitError +func (c *Client) runCommand(ctx context.Context, cmd string) (string, string, error) { + if c.client == nil { + if err := c.Connect(ctx); err != nil { + return "", "", fmt.Errorf("failed to connect: %w", err) + } + } + + sess, err := c.client.NewSession() + if err != nil { + return "", "", fmt.Errorf("failed to create session, try a new connection: %w", err) + } + defer func() { + if err := sess.Close(); err != nil { + fmt.Printf("failed to close session: %v\n", 
err) + } + }() + + var stdOutBuffer bytes.Buffer + sess.Stdout = &stdOutBuffer + var stdErrBuffer bytes.Buffer + sess.Stderr = &stdErrBuffer + + err = sess.Run(cmd) + return stdOutBuffer.String(), stdErrBuffer.String(), err +} + +func (c *Client) Close() error { + if c == nil { + return nil + } + if c.conn != nil { + if err := c.conn.Close(); err != nil { + return fmt.Errorf("failed to close connection: %w", err) + } + } + + c.conn = nil + return nil +} + +func (c Client) getSigner() (gossh.Signer, error) { + signer, err := gossh.ParsePrivateKey([]byte(c.privateKey)) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %w", err) + } + if c.Certificate != "" { + pubKey, _, _, _, err := gossh.ParseAuthorizedKey([]byte(c.Certificate)) + if err != nil { + return nil, fmt.Errorf("failed to parse certificate: %w", err) + } + + cert, ok := pubKey.(*gossh.Certificate) + if !ok { + return nil, fmt.Errorf("SSH public key must be a certificate") + } + signer, err = gossh.NewCertSigner(cert, signer) + if err != nil { + return nil, fmt.Errorf("failed to create cert signer: %w", err) + } + } + + return signer, nil +} + +type connResp struct { + err error + conn gossh.Conn + client *gossh.Client +} + +func (c *Client) Connect(ctx context.Context) error { + signer, err := c.getSigner() + if err != nil { + return fmt.Errorf("failed to get signer: %w", err) + } + + tcpConn, err := c.dial(ctx, "tcp", c.addr) + if err != nil { + return fmt.Errorf("failed to dial: %w", err) + } + + conf := &gossh.ClientConfig{ + User: c.user, + Auth: []gossh.AuthMethod{ + gossh.PublicKeys(signer), + }, + HostKeyCallback: gossh.InsecureIgnoreHostKey(), //nolint:gosec // audited + HostKeyAlgorithms: c.hostKeyAlgorithms, + } + + respCh := make(chan connResp) + + // ssh.NewClientConn doesn't take a context, so we need to handle cancelation on our end + go func() { + conn, chans, reqs, errr := gossh.NewClientConn(tcpConn, tcpConn.RemoteAddr().String(), conf) + if errr != nil { + respCh 
<- connResp{err: errr} + return + } + + client := gossh.NewClient(conn, chans, reqs) + + respCh <- connResp{nil, conn, client} + }() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case resp := <-respCh: + if resp.err != nil { + return resp.err + } + c.conn = resp.conn + c.client = resp.client + return nil + } + } +} + +//nolint:gosec // WARNING: do not use these keys for anything other than testing +const DoNotUseDummyPrivateKey = `-----BEGIN RSA PRIVATE KEY----- +MIICXQIBAAKBgQCEvcaEC2HvVDV277n6n23KXPwHoWX5mEkuoezqurwSJgq5grQz +Ka3pwdTmRd1CPM9UAXV7aK7UmpMyjSmukmna6CyLXCv61BDrodFb488p4MaPUnwG +FhilkjgcQLBWdHRKcUJZoszdY0kWVWbeUXzSrmTLzuGMmaN32dAXop31CQIDAQAB +AoGACAK33zIcp+fKDjJrY8+JPaQc5Yz87XIeQH0vIf9A6Et5bDaSD2BdiXTUF01y +C9RFoskvwNHRcy0c4vkX4dweHSvHboFAU0ygKU5Dfou1JlmJeK6J+2xrEVGLIyKP +aMWVpyqmDCAKUzO0jEzzDJCZ95KDw9OWS7SBxC9bsRS2soECQQDwyiO6dBWYvRE0 +8GDte+c9MbIdnzYuEeCXyGsK1prAaGLNHoOylp9yXj7M8UKyU+1LOZexjAXvv3tP +imEjHteRAkEAjSBdGf6LAPpGGwK3TuSi2GlJsLWW2trWBuY6+LY9nPnMDCTzYkxt +lD4lkCOxhcB6bNstbL9nBjoo3vHciC85+QJBAOBlUOSLGDFOSUHPnlTTOk1yCa7H +WAOZD3gEA5WHJ5KV9TV48Xy2GAPKRrZRRDnSMvr+whppBoNGLFGVAS9sp7ECQGvj +AumdWzyrF68Me4A3b3qLuwb5O1MiGp55oTmDcESx/liGYv2Rue+rNuIjN1It3Cmd +wPMyu5raGWaedV4y5FkCQQChhO3jMmLXQwCwLVCCfd9duiC1swwpvXm94Byk2h81 +l2FHbn+D8BPAoE/vO/eLAOQVDAgLu0evktWWdtBckUoZ +-----END RSA PRIVATE KEY-----` + +const PubKey = `-----BEGIN PUBLIC KEY----- +MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCEvcaEC2HvVDV277n6n23KXPwH +oWX5mEkuoezqurwSJgq5grQzKa3pwdTmRd1CPM9UAXV7aK7UmpMyjSmukmna6CyL +XCv61BDrodFb488p4MaPUnwGFhilkjgcQLBWdHRKcUJZoszdY0kWVWbeUXzSrmTL +zuGMmaN32dAXop31CQIDAQAB +-----END PUBLIC KEY-----` + +type TestSSHServerOptions struct { + Port string + PubKeyAuth bool + PubKeyDelay time.Duration + ExitCode int +} + +func StartTestSSHServer(options TestSSHServerOptions) (func() error, error) { + handler := func(s ssh.Session) { + authorizedKey := gossh.MarshalAuthorizedKey(s.PublicKey()) + _, err := io.WriteString(s, fmt.Sprintf("public key 
used by %s:\n", s.User())) // writes to client output + if err != nil { + fmt.Println(err) + } + _, err = s.Write(authorizedKey) + if err != nil { + fmt.Println(err) + } + _, err = s.Write([]byte(s.RawCommand())) + if err != nil { + fmt.Println(err) + } + err = s.Exit(options.ExitCode) + if err != nil { + fmt.Println(err) + } + } + + publicKeyOption := ssh.PublicKeyAuth(func(_ ssh.Context, _ ssh.PublicKey) bool { + time.Sleep(options.PubKeyDelay) + return options.PubKeyAuth // allow all keys, or use ssh.KeysEqual() to compare against known keys + }) + + server := ssh.Server{ + Addr: fmt.Sprintf(":%s", options.Port), + } + server.Handler = handler + err := server.SetOption(publicKeyOption) + if err != nil { + return nil, fmt.Errorf("failed to set public key option: %w", err) + } + + go func() { + err1 := server.ListenAndServe() + if err1 != nil { + fmt.Println(err1) + } + }() + time.Sleep(100 * time.Millisecond) + return server.Close, nil +} + +type ConnectionConfig struct { + User, HostPort, PrivKey string +} + +type WaitForSSHOptions struct { + Timeout time.Duration + ConnectionTimeout time.Duration + CheckCmd string + WaitTime time.Duration +} + +func (o *WaitForSSHOptions) SetDefault() { + if o.Timeout == 0 { + o.Timeout = 60 * time.Second + } + if o.ConnectionTimeout == 0 { + o.ConnectionTimeout = 30 * time.Second + } + if o.CheckCmd == "" { + o.CheckCmd = "echo 'connected'" // WARNING: assumes echo command exists + } + if o.WaitTime == 0 { + o.WaitTime = 1 * time.Second + } +} + +func WaitForSSH(ctx context.Context, c ConnectionConfig, options WaitForSSHOptions) error { + options.SetDefault() + errChan := make(chan error, 1) + errChan <- nil + t0 := time.Now() + + // Create a timeout context + timeoutCtx, cancel := context.WithTimeout(ctx, options.Timeout) + defer cancel() + + err := doWithTimeout(timeoutCtx, func(ctx context.Context) error { + waitForSSH(ctx, errChan, c, options) + return nil + }) + + lastSSHErr := <-errChan + t1 := time.Now() + if err != nil 
{ + err = fmt.Errorf("took %s: %w", t1.Sub(t0).String(), err) + if lastSSHErr != nil { + lastSSHErr = fmt.Errorf("last error: %w", lastSSHErr) + return fmt.Errorf("%w, %w", err, lastSSHErr) + } + return err + } + return nil +} + +func doWithTimeout(ctx context.Context, fn func(context.Context) error) error { + done := make(chan error, 1) + go func() { + defer func() { + if r := recover(); r != nil { + done <- fmt.Errorf("panic recovered: %v", r) + } + }() + done <- fn(ctx) + }() + + select { + case err := <-done: + return err + case <-ctx.Done(): + return ctx.Err() + } +} + +func waitForSSH(ctx context.Context, errChan chan error, c ConnectionConfig, options WaitForSSHOptions) { + for ctx.Err() == nil { + _ = <-errChan + tryCtx, cancel := context.WithTimeout(ctx, options.ConnectionTimeout) + sshErr := TrySSHConnect(tryCtx, c, options) + cancel() + errChan <- sshErr + if sshErr == nil { + return + } + time.Sleep(options.WaitTime) + } +} + +func TrySSHConnect(ctx context.Context, c ConnectionConfig, options WaitForSSHOptions) error { + con, err := ConnectToHost(ctx, c) + if err != nil { + return fmt.Errorf("failed to connect to host: %w", err) + } + defer func() { + if closeErr := con.Close(); closeErr != nil { + // Log close error but don't return it as it's not the primary error + fmt.Printf("warning: failed to close SSH connection: %v\n", closeErr) + } + }() + _, _, err = con.RunCommand(ctx, options.CheckCmd) + if err != nil { + return fmt.Errorf("failed to run check command: %w", err) + } + return nil +} + +var ( + localIP string + localIPOnce sync.Once +) + +func setLocalIP(ctx context.Context) { + localIPOnce.Do(func() { + ip, err := GetLocalIP(ctx) + if err != nil { + fmt.Printf("failed to get local IP: %v\n", err) + return + } + localIP = ip.String() + }) +} + +func GetLocalIP(ctx context.Context) (net.IP, error) { + dialer := net.Dialer{} + conn, err := dialer.DialContext(ctx, "udp", "8.8.8.8:80") + if err != nil { + return nil, fmt.Errorf("failed to dial 
// GetPublicIPStr asks https://api.ipify.org for the caller's public IP address
// and returns it as a string.
//
// An error is returned when the request cannot be built or sent, when the
// service answers with a status other than 200, when the body cannot be read,
// or when the body is not a valid IP address.
func GetPublicIPStr(ctx context.Context) (string, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.ipify.org", nil)
	if err != nil {
		return "", fmt.Errorf("failed to create request: %w", err)
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("failed to get public IP: %w", err)
	}
	defer func() {
		if err := resp.Body.Close(); err != nil {
			fmt.Printf("failed to close response body: %v\n", err)
		}
	}()

	// Without this check the body of an error response (rate limit, outage
	// page, ...) would be handed back to the caller as if it were the IP.
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("failed to get public IP: unexpected status %s", resp.Status)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("failed to read response body: %w", err)
	}

	ip := strings.TrimSpace(string(body))
	if net.ParseIP(ip) == nil {
		return "", fmt.Errorf("failed to get public IP: response %q is not an IP address", ip)
	}

	return ip, nil
}
"github.com/stretchr/testify/assert" + gossh "golang.org/x/crypto/ssh" +) + +var nextPort int32 = 3334 + +func getPort() string { + n := atomic.AddInt32(&nextPort, 1) + return fmt.Sprintf("%d", n) +} + +func Test_ConnectToHostSuccess(t *testing.T) { + t.Parallel() + ctx := context.Background() + port := getPort() + done, err := StartTestSSHServer( + TestSSHServerOptions{ + Port: port, + PubKeyAuth: true, + }, + ) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := done(); err != nil { + t.Fatal(err) + } + }() + + // res, err := exec.Command("ssh", "-o", "UserKnownHostsFile=/dev/null", "-o", "StrictHostKeyChecking=no", "-p", port, "localhost").CombinedOutput() + // mt.Println(string(res)) + _, err = ConnectToHost(ctx, ConnectionConfig{"ubuntu", fmt.Sprintf("localhost:%s", port), DoNotUseDummyPrivateKey}) + if !assert.NoError(t, err) { + return + } +} + +func Test_ConnectToHostRunCommandExit0(t *testing.T) { + t.Parallel() + ctx := context.Background() + port := getPort() + done, err := StartTestSSHServer( + TestSSHServerOptions{ + Port: port, + PubKeyAuth: true, + }, + ) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := done(); err != nil { + t.Fatal(err) + } + }() + + c, err := ConnectToHost(ctx, ConnectionConfig{"ubuntu", fmt.Sprintf("localhost:%s", port), DoNotUseDummyPrivateKey}) + if !assert.NoError(t, err) { + return + } + stdOut, stdErr, err := c.RunCommand(ctx, "echo hello") + if !assert.NoError(t, err) { + return + } + assert.Contains(t, stdOut, "echo hello") + assert.Empty(t, stdErr) +} + +func Test_ConnectToHostRunCommandExit1(t *testing.T) { + t.Parallel() + ctx := context.Background() + port := getPort() + done, err := StartTestSSHServer( + TestSSHServerOptions{ + Port: port, + PubKeyAuth: true, + ExitCode: 1, + }, + ) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := done(); err != nil { + t.Fatal(err) + } + }() + + c, err := ConnectToHost(ctx, ConnectionConfig{"ubuntu", fmt.Sprintf("localhost:%s", port), 
DoNotUseDummyPrivateKey}) + if !assert.NoError(t, err) { + return + } + stdOut, stdErr, err := c.RunCommand(ctx, "echo hello") + if !assert.Error(t, err) { + if !assert.ErrorIs(t, err, &gossh.ExitError{}) { + res, _ := err.(*gossh.ExitError) + assert.Equal(t, res.ExitStatus(), 1) + return + } + return + } + assert.Contains(t, stdOut, "echo hello") + assert.Empty(t, stdErr) +} + +func Test_ConnectToHostPubKeyFail(t *testing.T) { + t.Parallel() + ctx := context.Background() + port := getPort() + done, err := StartTestSSHServer( + TestSSHServerOptions{ + Port: port, + PubKeyAuth: false, + }, + ) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := done(); err != nil { + t.Fatal(err) + } + }() + + _, err = ConnectToHost(ctx, ConnectionConfig{"ubuntu", fmt.Sprintf("localhost:%s", port), DoNotUseDummyPrivateKey}) + assert.ErrorContains(t, err, "unable to authenticate") +} + +func Test_FailConnectionRefusedSSH(t *testing.T) { + t.Parallel() + // no server running + ctx := context.Background() + _, err := ConnectToHost(ctx, ConnectionConfig{"ubuntu", fmt.Sprintf("localhost:%s", "3333"), DoNotUseDummyPrivateKey}) + assert.ErrorContains(t, err, "connection refused") +} + +func Test_TestSSHTimeoutKey(t *testing.T) { + t.Parallel() + // no server running + ctx := context.Background() + timeout := time.Millisecond * 250 + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + port := getPort() + done, err := StartTestSSHServer(TestSSHServerOptions{ + Port: port, + PubKeyAuth: true, + PubKeyDelay: timeout * 2, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := done(); err != nil { + t.Fatal(err) + } + }() + + _, err = ConnectToHost(ctx, ConnectionConfig{"ubuntu", fmt.Sprintf("localhost:%s", port), DoNotUseDummyPrivateKey}) + // assert.ErrorContains(t, err, "timeout") + assert.ErrorContains(t, err, "context deadline exceeded") +} + +func Test_CertficiateSSH(t *testing.T) { + t.Parallel() + // fails: don't understand how to use 
certificates + t.Skip() + c := Client{ + privateKey: DoNotUseDummyPrivateKey, + Certificate: PubKey, + } + _, err := c.getSigner() + assert.NoError(t, err) +} + +func Test_TrySSHConnect(t *testing.T) { + t.Parallel() + ctx := context.Background() + port := getPort() + done, err := StartTestSSHServer( + TestSSHServerOptions{ + Port: port, + PubKeyAuth: true, + ExitCode: 0, + }, + ) + if !assert.NoError(t, err) { + return + } + defer func() { + if err := done(); err != nil { + t.Fatal(err) + } + }() + err = TrySSHConnect(ctx, ConnectionConfig{ + User: "ubuntu", + HostPort: fmt.Sprintf("localhost:%s", port), + PrivKey: DoNotUseDummyPrivateKey, + }, WaitForSSHOptions{}) + assert.NoError(t, err) +} + +func Test_WaitForSSH(t *testing.T) { + t.Parallel() + ctx := context.Background() + port := getPort() + done, err := StartTestSSHServer( + TestSSHServerOptions{ + Port: port, + PubKeyAuth: true, + ExitCode: 0, + }, + ) + if !assert.NoError(t, err) { + return + } + defer func() { + if err := done(); err != nil { + t.Fatal(err) + } + }() + err = WaitForSSH(ctx, ConnectionConfig{ + User: "ubuntu", + HostPort: fmt.Sprintf("localhost:%s", port), + PrivKey: DoNotUseDummyPrivateKey, + }, WaitForSSHOptions{}) + assert.NoError(t, err) +} + +func Test_WaitForSSHFailWithRetry(t *testing.T) { + RetryTest(t, WaitForSSHFailFlaky, 3) +} + +func WaitForSSHFailFlaky(t *testing.T) { + t.Helper() + ctx := context.Background() + port := getPort() + done, err := StartTestSSHServer( + TestSSHServerOptions{ + Port: port, + PubKeyAuth: true, + ExitCode: 0, + }, + ) + if !assert.NoError(t, err) { + return + } + defer func() { + if err := done(); err != nil { + t.Fatal(err) + } + }() + t0 := time.Now() + err = WaitForSSH(ctx, ConnectionConfig{ + User: "ubuntu", + HostPort: fmt.Sprintf("localhost:%s", "3333"), + PrivKey: DoNotUseDummyPrivateKey, + }, WaitForSSHOptions{ + Timeout: time.Second * 5, + ConnectionTimeout: time.Second * 1, + }) + t1 := time.Now() + assert.ErrorContains(t, err, "context 
// RetryTest runs testFunc up to numRetries times and returns as soon as one
// run passes. If every run fails (or numRetries is not positive), t is marked
// as failed.
//
// Each run receives its own throwaway testing.T so that failures of earlier
// attempts do not mark t as failed. Each run also executes in its own
// goroutine: FailNow/Fatal end the calling goroutine via runtime.Goexit, and
// that must only end the attempt, not the test that called RetryTest.
func RetryTest(t *testing.T, testFunc func(t *testing.T), numRetries int) {
	t.Helper() // Mark this function as a helper
	for i := 0; i < numRetries; i++ {
		tt := &testing.T{}
		finished := make(chan struct{})
		go func() {
			// Runs on normal return and on runtime.Goexit alike.
			defer close(finished)
			testFunc(tt)
		}()
		<-finished
		if !tt.Failed() {
			return
		}
	}
	// If we reach here, all retries failed
	t.Errorf("test still failing after %d attempts", numRetries)
}
instance-attached storage is one of the more confusing pa - Some clouds (e.g. AWS) treat root and attached volumes differently (with separate APIs). - Others (e.g. Lambda) don’t expose volumes at all — only a total storage value. - Elastic volumes, ephemeral disks, and NVMe local storage are not modeled cleanly in v1. + +### Cluster Support Limitations +- The v1 design is fundamentally instance-centric and not conducive to cluster support. +- No abstractions for cluster-level operations, networking, or orchestration. +- Instance management is treated as individual resources rather than as part of a larger distributed system. +- Missing concepts like cluster membership, inter-instance communication, shared state, or cluster lifecycle management. +- For support to be added we may need to more fomally implement networks/vpcs or instance groups. diff --git a/pkg/v1/image.go b/pkg/v1/image.go index 5ce86f9..9683262 100644 --- a/pkg/v1/image.go +++ b/pkg/v1/image.go @@ -2,7 +2,12 @@ package v1 import ( "context" + "fmt" + "regexp" + "strings" "time" + + "github.com/brevdev/cloud/pkg/ssh" ) type CloudMachineImage interface { @@ -23,3 +28,86 @@ type Image struct { Name string CreatedAt time.Time } + +func ValidateInstanceImage(ctx context.Context, instance Instance, privateKey string) error { + // First ensure the instance is running and SSH accessible + sshUser := instance.SSHUser + sshPort := instance.SSHPort + publicIP := instance.PublicIP + + // Validate that we have the required SSH connection details + if sshUser == "" { + return fmt.Errorf("SSH user is not set for instance %s", instance.CloudID) + } + if sshPort == 0 { + return fmt.Errorf("SSH port is not set for instance %s", instance.CloudID) + } + if publicIP == "" { + return fmt.Errorf("public IP is not available for instance %s", instance.CloudID) + } + + // Connect to the instance via SSH + sshClient, err := ssh.ConnectToHost(ctx, ssh.ConnectionConfig{ + User: sshUser, + HostPort: fmt.Sprintf("%s:%d", 
publicIP, sshPort), + PrivKey: privateKey, + }) + if err != nil { + return fmt.Errorf("failed to connect to instance via SSH: %w", err) + } + defer func() { + if closeErr := sshClient.Close(); closeErr != nil { + // Log close error but don't return it as it's not the primary error + fmt.Printf("warning: failed to close SSH connection: %v\n", closeErr) + } + }() + + // Check 1: Verify x86_64 architecture + stdout, stderr, err := sshClient.RunCommand(ctx, "uname -m") + if err != nil { + return fmt.Errorf("failed to check architecture: %w, stdout: %s, stderr: %s", err, stdout, stderr) + } + if !strings.Contains(strings.TrimSpace(stdout), "x86_64") { + return fmt.Errorf("expected x86_64 architecture, got: %s", strings.TrimSpace(stdout)) + } + + // Check 2: Verify Ubuntu 20.04 or 22.04 + stdout, stderr, err = sshClient.RunCommand(ctx, "cat /etc/os-release | grep PRETTY_NAME") + if err != nil { + return fmt.Errorf("failed to check OS version: %w, stdout: %s, stderr: %s", err, stdout, stderr) + } + + parts := strings.Split(strings.TrimSpace(stdout), "=") + if len(parts) != 2 { + return fmt.Errorf("error: os pretty name not in format PRETTY_NAME=\"Ubuntu\": %s", stdout) + } + + // Remove quotes from the value + osVersion := strings.Trim(parts[1], "\"") + ubuntuRegex := regexp.MustCompile(`Ubuntu 20\.04|22\.04`) + if !ubuntuRegex.MatchString(osVersion) { + return fmt.Errorf("expected Ubuntu 20.04 or 22.04, got: %s", osVersion) + } + + // Check 3: Verify home directory + stdout, stderr, err = sshClient.RunCommand(ctx, "cd ~ && pwd") + if err != nil { + return fmt.Errorf("failed to check home directory: %w, stdout: %s, stderr: %s", err, stdout, stderr) + } + + homeDir := strings.TrimSpace(stdout) + if sshUser == "ubuntu" { + if !strings.Contains(homeDir, "/home/ubuntu") { + return fmt.Errorf("expected ubuntu user home directory to contain /home/ubuntu, got: %s", homeDir) + } + } else { + if !strings.Contains(homeDir, "/root") { + return fmt.Errorf("expected non-ubuntu user 
home directory to contain /root, got: %s", homeDir) + } + } + + fmt.Printf("Instance image validation passed for %s: architecture=%s, os=%s, home=%s\n", + instance.CloudID, "x86_64", osVersion, homeDir) + + return nil +} diff --git a/pkg/v1/instance.go b/pkg/v1/instance.go index 248cca4..199d83a 100644 --- a/pkg/v1/instance.go +++ b/pkg/v1/instance.go @@ -7,22 +7,29 @@ import ( "time" "github.com/alecthomas/units" + "github.com/brevdev/cloud/internal/collections" + "github.com/brevdev/cloud/pkg/ssh" "github.com/google/uuid" ) +type CloudInstanceReader interface { + GetInstance(ctx context.Context, id CloudProviderInstanceID) (*Instance, error) + ListInstances(ctx context.Context, args ListInstancesArgs) ([]Instance, error) +} + type CloudCreateTerminateInstance interface { // CreateInstance expects an instance object to exist if successful, and no instance to exist if there is ANY error // CloudClient Implementers: ensure that the instance is terminated if there is an error // Public ip is not always returned from create, but will exist when instance is in running state CreateInstance(ctx context.Context, attrs CreateInstanceAttrs) (*Instance, error) - GetInstance(ctx context.Context, id CloudProviderInstanceID) (*Instance, error) // may or may not be locationally scoped TerminateInstance(ctx context.Context, instanceID CloudProviderInstanceID) error // may or may not be locationally scoped - ListInstances(ctx context.Context, args ListInstancesArgs) ([]Instance, error) // return all known instances from cloud api perspective + GetMaxCreateRequestsPerMinute() int CloudInstanceType + CloudInstanceReader } func ValidateCreateInstance(ctx context.Context, client CloudCreateTerminateInstance, attrs CreateInstanceAttrs) (*Instance, error) { - t0 := time.Now() + t0 := time.Now().Add(-time.Minute) attrs.RefID = uuid.New().String() name, err := makeDebuggableName(attrs.Name) if err != nil { @@ -34,9 +41,9 @@ func ValidateCreateInstance(ctx context.Context, client 
CloudCreateTerminateInst return nil, err } var validationErr error - t1 := time.Now() + t1 := time.Now().Add(1 * time.Minute) diff := t1.Sub(t0) - if diff > 1*time.Minute { + if diff > 3*time.Minute { validationErr = errors.Join(validationErr, fmt.Errorf("create instance took too long: %s", diff)) } if i.CreatedAt.Before(t0) { @@ -46,7 +53,7 @@ func ValidateCreateInstance(ctx context.Context, client CloudCreateTerminateInst validationErr = errors.Join(validationErr, fmt.Errorf("createdAt is after t1: %s", i.CreatedAt)) } if i.Name != name { - validationErr = errors.Join(validationErr, fmt.Errorf("name mismatch: %s != %s", i.Name, name)) + fmt.Printf("name mismatch: %s != %s, input name does not mean return name will be stable\n", i.Name, name) } if i.RefID != attrs.RefID { validationErr = errors.Join(validationErr, fmt.Errorf("refID mismatch: %s != %s", i.RefID, attrs.RefID)) @@ -75,30 +82,26 @@ func ValidateListCreatedInstance(ctx context.Context, client CloudCreateTerminat if len(ins) == 0 { validationErr = errors.Join(validationErr, fmt.Errorf("no instances found")) } - if ins[0].Location != i.Location { - validationErr = errors.Join(validationErr, fmt.Errorf("location mismatch: %s != %s", ins[0].Location, i.Location)) - } - instanceIDsMap := map[CloudProviderInstanceID]Instance{} - for _, inst := range ins { - instanceIDsMap[inst.CloudID] = inst - } - inst, ok := instanceIDsMap[i.CloudID] - if !ok { + foundInstance := collections.Find(ins, func(inst Instance) bool { + return inst.CloudID == i.CloudID + }) + if foundInstance == nil { validationErr = errors.Join(validationErr, fmt.Errorf("instance not found: %s", i.CloudID)) - return validationErr } - if inst.RefID != i.RefID { - validationErr = errors.Join(validationErr, fmt.Errorf("refID mismatch: %s != %s", inst.RefID, i.RefID)) + if foundInstance.Location != i.Location { + validationErr = errors.Join(validationErr, fmt.Errorf("location mismatch: %s != %s", foundInstance.Location, i.Location)) + } else if 
foundInstance.RefID != i.RefID { + validationErr = errors.Join(validationErr, fmt.Errorf("refID mismatch: %s != %s", foundInstance.RefID, i.RefID)) } return validationErr } -func ValidateTerminateInstance(ctx context.Context, client CloudCreateTerminateInstance, instance Instance) error { +func ValidateTerminateInstance(ctx context.Context, client CloudCreateTerminateInstance, instance *Instance) error { err := client.TerminateInstance(ctx, instance.CloudID) if err != nil { return err } - // TODO wait for terminated + // TODO wait for instance to go into terminating state return nil } @@ -107,7 +110,7 @@ type CloudStopStartInstance interface { StartInstance(ctx context.Context, instanceID CloudProviderInstanceID) error } -func ValidateStopStartInstance(ctx context.Context, client CloudStopStartInstance, instance Instance) error { +func ValidateStopStartInstance(ctx context.Context, client CloudStopStartInstance, instance *Instance) error { err := client.StopInstance(ctx, instance.CloudID) if err != nil { return err @@ -238,6 +241,13 @@ const ( LifecycleStatusFailed LifecycleStatus = "failed" ) +const ( + PendingToRunningTimeout = 20 * time.Minute + RunningToStoppedTimeout = 10 * time.Minute + StoppedToRunningTimeout = 20 * time.Minute + RunningToTerminatedTimeout = 20 * time.Minute +) + type CloudProviderInstanceID string type ListInstancesArgs struct { @@ -284,3 +294,41 @@ func makeDebuggableName(name string) (string, error) { } return fmt.Sprintf("%s-%s", name, time.Now().In(pt).Format("2006-01-02-15-04-05")), nil } + +const RunningSSHTimeout = 10 * time.Minute + +func ValidateInstanceSSHAccessible(ctx context.Context, client CloudInstanceReader, instance *Instance, privateKey string) error { + var err error + instance, err = WaitForInstanceLifecycleStatus(ctx, client, instance, LifecycleStatusRunning, PendingToRunningTimeout) + if err != nil { + return err + } + sshUser := instance.SSHUser + sshPort := instance.SSHPort + publicIP := instance.PublicIP + // 
Validate that we have the required SSH connection details + if sshUser == "" { + return fmt.Errorf("SSH user is not set for instance %s", instance.CloudID) + } + if sshPort == 0 { + return fmt.Errorf("SSH port is not set for instance %s", instance.CloudID) + } + if publicIP == "" { + return fmt.Errorf("public IP is not available for instance %s", instance.CloudID) + } + + err = ssh.WaitForSSH(ctx, ssh.ConnectionConfig{ + User: sshUser, + HostPort: fmt.Sprintf("%s:%d", publicIP, sshPort), + PrivKey: privateKey, + }, ssh.WaitForSSHOptions{ + Timeout: RunningSSHTimeout, + }) + if err != nil { + return err + } + + fmt.Printf("SSH connection validated successfully for %s@%s:%d\n", sshUser, publicIP, sshPort) + + return nil +} diff --git a/pkg/v1/instancetype.go b/pkg/v1/instancetype.go index 0180c3e..5377718 100644 --- a/pkg/v1/instancetype.go +++ b/pkg/v1/instancetype.go @@ -9,6 +9,7 @@ import ( "github.com/alecthomas/units" "github.com/bojanz/currency" + "github.com/google/go-cmp/cmp" ) type InstanceTypeID string @@ -47,6 +48,28 @@ type InstanceType struct { CanModifyFirewallRules bool } +func MakeGenericInstanceTypeID(instanceType InstanceType) InstanceTypeID { + if instanceType.ID != "" { + return instanceType.ID + } + subLoc := noSubLocation + if len(instanceType.AvailableAzs) > 0 { + subLoc = instanceType.AvailableAzs[0] + } + return InstanceTypeID(fmt.Sprintf("%s-%s-%s", instanceType.Location, subLoc, instanceType.Type)) +} + +func MakeGenericInstanceTypeIDFromInstance(instance Instance) InstanceTypeID { + if instance.InstanceTypeID != "" { + return instance.InstanceTypeID + } + subLoc := noSubLocation + if instance.SubLocation != "" { + subLoc = instance.SubLocation + } + return InstanceTypeID(fmt.Sprintf("%s-%s-%s", instance.Location, subLoc, instance.InstanceType)) +} + type GPU struct { Count int32 Memory units.Base2Bytes @@ -77,7 +100,7 @@ type GetInstanceTypeArgs struct { // ValidateGetInstanceTypes validates that the GetInstanceTypes functionality works 
correctly // by testing that filtering by specific instance types returns the expected results -func ValidateGetInstanceTypes(ctx context.Context, client CloudInstanceType) error { +func ValidateGetInstanceTypes(ctx context.Context, client CloudInstanceType) error { //nolint:funlen,gocyclo // todo refactor // Get all instance types first allTypes, err := client.GetInstanceTypes(ctx, GetInstanceTypeArgs{}) if err != nil { @@ -88,7 +111,7 @@ func ValidateGetInstanceTypes(ctx context.Context, client CloudInstanceType) err return errors.New("no instance types available for validation") } - // Test 1: Deterministic results - multiple calls should return the same results + // Test 1: Deterministic results - multiple calls should return the same results (order-insensitive) allTypes2, err := client.GetInstanceTypes(ctx, GetInstanceTypeArgs{}) if err != nil { return fmt.Errorf("failed to get all instance types on second call: %w", err) @@ -98,8 +121,30 @@ func ValidateGetInstanceTypes(ctx context.Context, client CloudInstanceType) err normalizedTypes1 := normalizeInstanceTypes(allTypes) normalizedTypes2 := normalizeInstanceTypes(allTypes2) - if !reflect.DeepEqual(normalizedTypes1, normalizedTypes2) { - return fmt.Errorf("instance types are not deterministic between calls") + // Build maps keyed by ID for order-insensitive comparison + map1 := make(map[InstanceTypeID]InstanceType) + for _, t := range normalizedTypes1 { + map1[t.ID] = t + } + map2 := make(map[InstanceTypeID]InstanceType) + for _, t := range normalizedTypes2 { + map2[t.ID] = t + } + + // Compare keys + if len(map1) != len(map2) { + return fmt.Errorf("instance types are not deterministic between calls: different number of types (%d vs %d)", len(map1), len(map2)) + } + for id, t1 := range map1 { + t2, ok := map2[id] + if !ok { + return fmt.Errorf("instance type ID %s present in first call but missing in second", id) + } + if !reflect.DeepEqual(t1, t2) { + diff := cmp.Diff(t1, t2) + fmt.Printf("Instance type with 
ID %s differs between calls. Diff:\n%s\n", id, diff) + return fmt.Errorf("instance type with ID %s differs between calls", id) + } } // Test 2: ID stability and uniqueness @@ -132,68 +177,86 @@ func ValidateGetInstanceTypes(ctx context.Context, client CloudInstanceType) err expectedType.SubLocation = "" expectedType.AvailableAzs = nil - actualType := filteredTypes[0] - actualType.ID = "" - actualType.SubLocation = "" - actualType.AvailableAzs = nil - - // Use reflection to compare the structs - if !reflect.DeepEqual(expectedType, actualType) { + // Find the matching type in filteredTypes by ID (since order is not guaranteed) + var actualType InstanceType + found := false + for _, t := range filteredTypes { + tmp := t + tmp.ID = "" + tmp.SubLocation = "" + tmp.AvailableAzs = nil + if reflect.DeepEqual(expectedType, tmp) { + actualType = tmp + found = true + break + } + } + if !found { + // If not found by struct equality, just compare the first filtered type for debugging + actualType = filteredTypes[0] + actualType.ID = "" + actualType.SubLocation = "" + actualType.AvailableAzs = nil + diff := cmp.Diff(expectedType, actualType) + fmt.Printf("Filtered instance type does not match expected type. 
Diff:\n%s\n", diff) return fmt.Errorf("filtered instance type does not match expected type: expected %+v, got %+v", expectedType, actualType) } return nil } -// ValidateRegionalInstanceTypes validates that regional filtering works correctly -// by comparing regional results with all-region results using CloudLocation capabilities -func ValidateRegionalInstanceTypes(ctx context.Context, client CloudInstanceType) error { - // Get regional instance types (default behavior - typically current region) - regionalTypes, err := client.GetInstanceTypes(ctx, GetInstanceTypeArgs{}) +// ValidateLocationalInstanceTypes validates that locational filtering works correctly +// by comparing locational results with all-location results using CloudLocation capabilities +func ValidateLocationalInstanceTypes(ctx context.Context, client CloudInstanceType) error { + // Get all-location instance types by requesting from all locations + allLocationTypes, err := client.GetInstanceTypes(ctx, GetInstanceTypeArgs{ + Locations: All, + }) if err != nil { - return fmt.Errorf("failed to get regional instance types: %w", err) + // If all-location is not supported, skip this validation + return fmt.Errorf("all-location instance types not supported: %w", err) } - if len(regionalTypes) == 0 { - return errors.New("no regional instance types available for validation") + if len(allLocationTypes) == 0 { + return errors.New("no all-location instance types available for validation") } - // Get all-region instance types by requesting from all locations - allRegionTypes, err := client.GetInstanceTypes(ctx, GetInstanceTypeArgs{ - Locations: All, + locationToTest := allLocationTypes[0].Location + // Get locational instance types (default behavior - typically current location) + locationalTypes, err := client.GetInstanceTypes(ctx, GetInstanceTypeArgs{ + Locations: LocationsFilter{locationToTest}, }) if err != nil { - // If all-region is not supported, skip this validation - return fmt.Errorf("all-region instance 
types not supported: %w", err) + return fmt.Errorf("failed to get locational instance types: %w", err) } - if len(allRegionTypes) == 0 { - return errors.New("no all-region instance types available for validation") + if len(locationalTypes) == 0 { + return errors.New("no locational instance types available for validation") } - // Validate that regional results are a subset of all-region results - if len(regionalTypes) >= len(allRegionTypes) { - return fmt.Errorf("regional instance types (%d) should be fewer than all-region types (%d)", - len(regionalTypes), len(allRegionTypes)) + // Validate that locational results are a subset of all-location results + if len(locationalTypes) >= len(allLocationTypes) { + return fmt.Errorf("locational instance types (%d) should be fewer than all-location types (%d)", + len(locationalTypes), len(allLocationTypes)) } - // Create a map of all-region types for efficient lookup - allRegionMap := make(map[InstanceTypeID]InstanceType) - for _, instanceType := range allRegionTypes { - allRegionMap[instanceType.ID] = instanceType + // Create a map of all-location types for efficient lookup + allLocationMap := make(map[InstanceTypeID]InstanceType) + for _, instanceType := range allLocationTypes { + allLocationMap[instanceType.ID] = instanceType } - // Validate that all regional types exist in all-region results - for _, regionalType := range regionalTypes { - if _, exists := allRegionMap[regionalType.ID]; !exists { - return fmt.Errorf("regional instance type %s not found in all-region results", regionalType.ID) + // Validate that all locational types exist in all-location results + for _, locationalType := range locationalTypes { + if _, exists := allLocationMap[locationalType.ID]; !exists { + return fmt.Errorf("locational instance type %s not found in all-location results", locationalType.ID) } } - // Additional validation: ensure regional types have appropriate location information - for _, regionalType := range regionalTypes { - if 
regionalType.Location == "" { - return fmt.Errorf("regional instance type %s should have location information", regionalType.ID) + // Additional validation: ensure locational types have appropriate location information + for _, locationalType := range locationalTypes { + if locationalType.Location == "" { + return fmt.Errorf("locational instance type %s should have location information", locationalType.ID) } } @@ -244,12 +307,16 @@ func ValidateStableInstanceTypeIDs(ctx context.Context, client CloudInstanceType return errors.New("stable IDs list cannot be empty") } - // Validate that all stable IDs exist in current instance types + // Validate that all stable IDs exist in current instance types, collecting all errors + var errs []error for _, stableID := range stableIDs { if _, exists := typesByID[stableID]; !exists { - return fmt.Errorf("instance type id %s should be stable but not found", stableID) + errs = append(errs, fmt.Errorf("instance type id %s should be stable but not found", stableID)) // if this fails, we may need to coordinate a migration of the stable ID } } + if len(errs) > 0 { + return errors.Join(errs...) 
+ } // Validate that all instance types have required properties for _, instanceType := range allTypes { diff --git a/pkg/v1/location.go b/pkg/v1/location.go index 3c431ae..5518595 100644 --- a/pkg/v1/location.go +++ b/pkg/v1/location.go @@ -57,3 +57,5 @@ func ValidateGetLocations(ctx context.Context, client CloudLocation) error { } return nil } + +const noSubLocation = "noSub" diff --git a/pkg/v1/notimplemented.go b/pkg/v1/notimplemented.go index 3e82cea..afe02d7 100644 --- a/pkg/v1/notimplemented.go +++ b/pkg/v1/notimplemented.go @@ -122,3 +122,7 @@ func (c notImplCloudClient) MergeInstanceForUpdate(_, i Instance) Instance { func (c notImplCloudClient) MergeInstanceTypeForUpdate(_, i InstanceType) InstanceType { return i } + +func (c notImplCloudClient) GetMaxCreateRequestsPerMinute() int { + return 10 +} diff --git a/pkg/v1/waiters.go b/pkg/v1/waiters.go new file mode 100644 index 0000000..baf3de4 --- /dev/null +++ b/pkg/v1/waiters.go @@ -0,0 +1,92 @@ +package v1 + +import ( + "context" + "time" +) + +func WaitForInstanceLifecycleStatus(ctx context.Context, + client CloudInstanceReader, + instance *Instance, + status LifecycleStatus, + timeout time.Duration, +) (*Instance, error) { + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + + timeoutCh := time.After(timeout) + var lastInstance *Instance + var lastErr error + + for { + select { + case <-ctx.Done(): + return instance, ctx.Err() + case <-timeoutCh: + if lastInstance != nil { + return lastInstance, &InstanceWaitTimeoutError{ + Instance: lastInstance, + Desired: status, + Err: lastErr, + } + } + return instance, &InstanceWaitTimeoutError{ + Instance: instance, + Desired: status, + Err: lastErr, + } + case <-ticker.C: + inst, err := client.GetInstance(ctx, instance.CloudID) + if err != nil { + // If instance is not found, return error immediately + return inst, &InstanceWaitNotFoundError{ + InstanceID: instance.CloudID, + Err: err, + } + } + lastInstance = inst + if inst.Status.LifecycleStatus 
== status { + return inst, nil + } + } + } +} + +// InstanceWaitTimeoutError is returned when waiting for an instance times out. +type InstanceWaitTimeoutError struct { + Instance *Instance + Desired LifecycleStatus + Err error +} + +func (e *InstanceWaitTimeoutError) Error() string { + return "timeout waiting for instance " + string(e.Instance.CloudID) + + " to reach status " + string(e.Desired) + + ", last known status: " + string(e.Instance.Status.LifecycleStatus) + + ", last error: " + errString(e.Err) +} + +func (e *InstanceWaitTimeoutError) Unwrap() error { + return e.Err +} + +// InstanceWaitNotFoundError is returned when the instance is not found during wait. +type InstanceWaitNotFoundError struct { + InstanceID CloudProviderInstanceID + Err error +} + +func (e *InstanceWaitNotFoundError) Error() string { + return "instance not found: " + string(e.InstanceID) + ", error: " + errString(e.Err) +} + +func (e *InstanceWaitNotFoundError) Unwrap() error { + return e.Err +} + +func errString(err error) string { + if err == nil { + return "" + } + return err.Error() +} diff --git a/pkg/v1/waiters_test.go b/pkg/v1/waiters_test.go new file mode 100644 index 0000000..3339253 --- /dev/null +++ b/pkg/v1/waiters_test.go @@ -0,0 +1,371 @@ +package v1 + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test suite for WaitForInstanceLifecycleStatus function. +// Some tests are skipped in short mode (-short flag) to speed up development. +// Run without -short flag for full test coverage including longer-running tests. 
+ +// mockCloudInstanceReader implements CloudInstanceReader for testing +type mockCloudInstanceReader struct { + instances map[CloudProviderInstanceID]*Instance + errors map[CloudProviderInstanceID]error + calls map[CloudProviderInstanceID]int + // For status transition testing + statusSequence map[CloudProviderInstanceID][]LifecycleStatus + sequenceIndex map[CloudProviderInstanceID]int +} + +func newMockCloudInstanceReader() *mockCloudInstanceReader { + return &mockCloudInstanceReader{ + instances: make(map[CloudProviderInstanceID]*Instance), + errors: make(map[CloudProviderInstanceID]error), + calls: make(map[CloudProviderInstanceID]int), + statusSequence: make(map[CloudProviderInstanceID][]LifecycleStatus), + sequenceIndex: make(map[CloudProviderInstanceID]int), + } +} + +func (m *mockCloudInstanceReader) GetInstance(_ context.Context, id CloudProviderInstanceID) (*Instance, error) { + m.calls[id]++ + + if err, exists := m.errors[id]; exists { + return nil, err + } + + // Check if we have a status sequence for this instance + if sequence, exists := m.statusSequence[id]; exists { + index := m.sequenceIndex[id] + if index < len(sequence) { + status := sequence[index] + m.sequenceIndex[id]++ + return &Instance{ + CloudID: id, + Status: Status{ + LifecycleStatus: status, + }, + }, nil + } + // If we've exhausted the sequence, use the last status + if len(sequence) > 0 { + lastStatus := sequence[len(sequence)-1] + return &Instance{ + CloudID: id, + Status: Status{ + LifecycleStatus: lastStatus, + }, + }, nil + } + } + + if instance, exists := m.instances[id]; exists { + return instance, nil + } + + return nil, errors.New("instance not found") +} + +func (m *mockCloudInstanceReader) ListInstances(_ context.Context, _ ListInstancesArgs) ([]Instance, error) { + // Not used in tests, but required by interface + return nil, nil +} + +// setStatusSequence sets up a sequence of statuses that will be returned for an instance +func (m *mockCloudInstanceReader) 
setStatusSequence(instanceID CloudProviderInstanceID, sequence []LifecycleStatus) { + m.statusSequence[instanceID] = sequence + m.sequenceIndex[instanceID] = 0 +} + +func TestWaitForInstanceLifecycleStatus_Success(t *testing.T) { + client := newMockCloudInstanceReader() + instanceID := CloudProviderInstanceID("test-instance-123") + + // Set up instance that is already in running status + instance := &Instance{ + CloudID: instanceID, + Status: Status{ + LifecycleStatus: LifecycleStatusRunning, + }, + } + client.instances[instanceID] = instance + + ctx := context.Background() + timeout := 5 * time.Second + + _, err := WaitForInstanceLifecycleStatus(ctx, client, instance, LifecycleStatusRunning, timeout) + + assert.NoError(t, err) + assert.Equal(t, 1, client.calls[instanceID]) +} + +func TestWaitForInstanceLifecycleStatus_StatusTransition(t *testing.T) { + if testing.Short() { + t.Skip("skipping status transition test in short mode") + } + + client := newMockCloudInstanceReader() + instanceID := CloudProviderInstanceID("test-instance-456") + + // Set up a status sequence: pending -> pending -> running + client.setStatusSequence(instanceID, []LifecycleStatus{ + LifecycleStatusPending, + LifecycleStatusPending, + LifecycleStatusRunning, + }) + + instance := &Instance{ + CloudID: instanceID, + Status: Status{ + LifecycleStatus: LifecycleStatusPending, + }, + } + + ctx := context.Background() + timeout := 10 * time.Second + + _, err := WaitForInstanceLifecycleStatus(ctx, client, instance, LifecycleStatusRunning, timeout) + + assert.NoError(t, err) + assert.GreaterOrEqual(t, client.calls[instanceID], 3, "Expected at least 3 calls to GetInstance") +} + +func TestWaitForInstanceLifecycleStatus_Timeout(t *testing.T) { + client := newMockCloudInstanceReader() + instanceID := CloudProviderInstanceID("test-instance-timeout") + + // Instance stays in pending status + instance := &Instance{ + CloudID: instanceID, + Status: Status{ + LifecycleStatus: LifecycleStatusPending, + }, + 
} + client.instances[instanceID] = instance + + ctx := context.Background() + timeout := 100 * time.Millisecond // Short timeout for testing + + _, err := WaitForInstanceLifecycleStatus(ctx, client, instance, LifecycleStatusRunning, timeout) + + require.Error(t, err) + + timeoutErr, ok := err.(*InstanceWaitTimeoutError) + require.True(t, ok, "Expected InstanceWaitTimeoutError, got: %T", err) + + assert.Equal(t, LifecycleStatusRunning, timeoutErr.Desired) + assert.Equal(t, instanceID, timeoutErr.Instance.CloudID) + // With a 100ms timeout, we might not get any calls due to the 2-second ticker + // The function will timeout before the first ticker fires + assert.GreaterOrEqual(t, client.calls[instanceID], 0, "Expected 0 or more calls to GetInstance") +} + +func TestWaitForInstanceLifecycleStatus_ContextCancellation(t *testing.T) { + client := newMockCloudInstanceReader() + instanceID := CloudProviderInstanceID("test-instance-context") + + instance := &Instance{ + CloudID: instanceID, + Status: Status{ + LifecycleStatus: LifecycleStatusPending, + }, + } + client.instances[instanceID] = instance + + ctx, cancel := context.WithCancel(context.Background()) + timeout := 5 * time.Second + + // Cancel context after a short delay + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + + _, err := WaitForInstanceLifecycleStatus(ctx, client, instance, LifecycleStatusRunning, timeout) + + require.Error(t, err) + assert.Equal(t, context.Canceled, err) +} + +func TestWaitForInstanceLifecycleStatus_InstanceNotFound(t *testing.T) { + client := newMockCloudInstanceReader() + instanceID := CloudProviderInstanceID("test-instance-notfound") + + // Set up client to return "not found" error + client.errors[instanceID] = errors.New("instance not found") + + instance := &Instance{ + CloudID: instanceID, + Status: Status{ + LifecycleStatus: LifecycleStatusPending, + }, + } + + ctx := context.Background() + timeout := 5 * time.Second + + _, err := 
WaitForInstanceLifecycleStatus(ctx, client, instance, LifecycleStatusRunning, timeout) + + require.Error(t, err) + + notFoundErr, ok := err.(*InstanceWaitNotFoundError) + require.True(t, ok, "Expected InstanceWaitNotFoundError, got: %T", err) + + assert.Equal(t, instanceID, notFoundErr.InstanceID) + assert.Equal(t, 1, client.calls[instanceID]) +} + +func TestWaitForInstanceLifecycleStatus_ErrorString(t *testing.T) { + // Test error string formatting + instance := Instance{ + CloudID: "test-instance", + Status: Status{ + LifecycleStatus: LifecycleStatusPending, + }, + } + + timeoutErr := &InstanceWaitTimeoutError{ + Instance: &instance, + Desired: LifecycleStatusRunning, + Err: errors.New("test error"), + } + + errorStr := timeoutErr.Error() + assert.Contains(t, errorStr, "timeout waiting for instance test-instance to reach status running") + assert.Contains(t, errorStr, "test error") + + notFoundErr := &InstanceWaitNotFoundError{ + InstanceID: "test-instance", + Err: errors.New("not found"), + } + + errorStr = notFoundErr.Error() + assert.Contains(t, errorStr, "instance not found: test-instance") + assert.Contains(t, errorStr, "not found") +} + +func TestWaitForInstanceLifecycleStatus_ErrorUnwrap(t *testing.T) { + originalErr := errors.New("original error") + + timeoutErr := &InstanceWaitTimeoutError{ + Instance: &Instance{CloudID: "test"}, + Desired: LifecycleStatusRunning, + Err: originalErr, + } + + assert.Equal(t, originalErr, timeoutErr.Unwrap()) + + notFoundErr := &InstanceWaitNotFoundError{ + InstanceID: "test", + Err: originalErr, + } + + assert.Equal(t, originalErr, notFoundErr.Unwrap()) +} + +func TestWaitForInstanceLifecycleStatus_ErrString(t *testing.T) { + // Test errString helper function + assert.Equal(t, "", errString(nil)) + + testErr := errors.New("test error") + assert.Equal(t, "test error", errString(testErr)) +} + +func TestWaitForInstanceLifecycleStatus_AllLifecycleStatuses(t *testing.T) { + if testing.Short() { + t.Skip("skipping all 
lifecycle statuses test in short mode") + } + + // Test that the function works with all possible lifecycle statuses + statuses := []LifecycleStatus{ + LifecycleStatusPending, + LifecycleStatusRunning, + LifecycleStatusStopping, + LifecycleStatusStopped, + LifecycleStatusSuspending, + LifecycleStatusSuspended, + LifecycleStatusTerminating, + LifecycleStatusTerminated, + LifecycleStatusFailed, + } + + for _, status := range statuses { + t.Run(string(status), func(t *testing.T) { + client := newMockCloudInstanceReader() + instanceID := CloudProviderInstanceID("test-instance-" + string(status)) + + instance := &Instance{ + CloudID: instanceID, + Status: Status{ + LifecycleStatus: status, + }, + } + client.instances[instanceID] = instance + + ctx := context.Background() + timeout := 3 * time.Second // Give enough time for the ticker to fire + + _, err := WaitForInstanceLifecycleStatus(ctx, client, instance, status, timeout) + + assert.NoError(t, err) + assert.Equal(t, 1, client.calls[instanceID]) + }) + } +} + +func TestWaitForInstanceLifecycleStatus_TimeoutWithLastInstance(t *testing.T) { + client := newMockCloudInstanceReader() + instanceID := CloudProviderInstanceID("test-instance-timeout-last") + + // Set up instance that stays in pending status + instance := &Instance{ + CloudID: instanceID, + Status: Status{ + LifecycleStatus: LifecycleStatusPending, + }, + } + client.instances[instanceID] = instance + + ctx := context.Background() + timeout := 100 * time.Millisecond + + _, err := WaitForInstanceLifecycleStatus(ctx, client, instance, LifecycleStatusRunning, timeout) + + require.Error(t, err) + + timeoutErr, ok := err.(*InstanceWaitTimeoutError) + require.True(t, ok) + + // Verify that the timeout error contains the last known instance + assert.Equal(t, instanceID, timeoutErr.Instance.CloudID) + assert.Equal(t, LifecycleStatusPending, timeoutErr.Instance.Status.LifecycleStatus) +} + +// Benchmark test for performance +func BenchmarkWaitForInstanceLifecycleStatus(b 
*testing.B) { + client := newMockCloudInstanceReader() + instanceID := CloudProviderInstanceID("benchmark-instance") + + instance := &Instance{ + CloudID: instanceID, + Status: Status{ + LifecycleStatus: LifecycleStatusRunning, + }, + } + client.instances[instanceID] = instance + + ctx := context.Background() + timeout := 1 * time.Second + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = WaitForInstanceLifecycleStatus(ctx, client, instance, LifecycleStatusRunning, timeout) + } +}