diff --git a/coverage/coverage.html b/coverage/coverage.html index 9f8afc4..cdb4acc 100644 --- a/coverage/coverage.html +++ b/coverage/coverage.html @@ -59,15 +59,15 @@ - + - + - + - + - + @@ -214,28 +214,25 @@ import ( "context" - "fmt" - - "github.com/rebelopsio/duet/internal/config/ssh" ) -type Executor struct { - sshClient *ssh.Client +// Executor defines the interface for executing commands +type ExecutorInterface interface { + Execute(ctx context.Context, command string) (string, error) } -func NewExecutor(sshConfig *ssh.Config) (*Executor, error) { - client, err := ssh.NewClient(sshConfig) - if err != nil { - return nil, fmt.Errorf("failed to create SSH client: %w", err) - } - - return &Executor{ - sshClient: client, - }, nil +type Executor struct { + executor ExecutorInterface } +func NewExecutor(executor ExecutorInterface) *Executor { + return &Executor{ + executor: executor, + } +} + func (e *Executor) Execute(ctx context.Context, command string) (string, error) { - return e.sshClient.Execute(ctx, command) + return e.executor.Execute(ctx, command) } @@ -245,127 +242,230 @@ "context" "fmt" "io" + "net" + "strings" + "sync" + "time" "golang.org/x/crypto/ssh" ) +// Config holds the SSH client configuration type Config struct { Host string User string PrivateKey string Port int + Timeout time.Duration } +// Client represents an SSH client type Client struct { config *Config client *ssh.Client } -func NewClient(config *Config) (*Client, error) { - signer, err := ssh.ParsePrivateKey([]byte(config.PrivateKey)) +// NewClient creates a new SSH client with timeouts +func NewClient(config *Config) (*Client, error) { + if config.Timeout == 0 { + config.Timeout = 30 * time.Second + } + + // Parse the private key + signer, err := ssh.ParsePrivateKey([]byte(config.PrivateKey)) if err != nil { return nil, fmt.Errorf("failed to parse private key: %w", err) } - sshConfig := &ssh.ClientConfig{ + // Create SSH client config + sshConfig := &ssh.ClientConfig{ User: config.User, Auth: []ssh.AuthMethod{ ssh.PublicKeys(signer), }, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), // Note: In production, use proper host key verification + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: config.Timeout, } + // Create a connection with timeout addr := fmt.Sprintf("%s:%d", config.Host, config.Port) - client, err := ssh.Dial("tcp", addr, sshConfig) + conn, err := net.DialTimeout("tcp", addr, config.Timeout) if err != nil { - return nil, fmt.Errorf("failed to dial SSH: %w", err) + return nil, fmt.Errorf("failed to connect: %w", err) } - return &Client{ + // Set connection deadline + if err := conn.SetDeadline(time.Now().Add(config.Timeout)); err != nil { + closeErr := conn.Close() + if closeErr != nil { + return nil, fmt.Errorf("failed to set connection deadline and close connection: %v, close error: %w", err, closeErr) + } + return nil, fmt.Errorf("failed to set connection deadline: %w", err) + } + + // Create new SSH client connection + c, chans, reqs, err := ssh.NewClientConn(conn.(*net.TCPConn), addr, sshConfig) + if err != nil { + closeErr := conn.Close() + if closeErr != nil { + return nil, fmt.Errorf("failed to create SSH connection and close connection: %v, close error: %w", err, closeErr) + } + return nil, fmt.Errorf("failed to create SSH connection: %w", err) + } + + // Clear the deadline after successful handshake + if err := conn.SetDeadline(time.Time{}); err != nil { + closeErr := c.Close() + if closeErr != nil { + return nil, fmt.Errorf("failed to clear connection deadline and close client: %v, close error: %w", err, closeErr) + } + return nil, fmt.Errorf("failed to clear connection deadline: %w", err) + } + + client := ssh.NewClient(c, chans, reqs) + + return &Client{ config: config, client: client, }, nil } -func (c *Client) Close() error { - if c.client != nil { +// Close closes the SSH connection +func (c *Client) Close() error { + if c.client != nil { return c.client.Close() } return nil } -func (c *Client) Execute(ctx context.Context, command string) (output string, err error) { +// ValidateConnection tests if the SSH connection is working +func (c *Client) ValidateConnection() error { session, err := c.client.NewSession() if err != nil { - return "", fmt.Errorf("failed to create session: %w", err) + return fmt.Errorf("failed to create session: %w", err) } + defer func() { + if err := session.Close(); err != nil { + fmt.Printf("error closing session: %v\n", err) + } + }() + return nil +} - defer func() { - closeErr := session.Close() - if err == nil { - // If there was no error from the command, return any close error - err = closeErr - } else if closeErr != nil { - // If there were both command and close errors, combine them - err = fmt.Errorf("command error: %w; close error: %v", err, closeErr) +// Execute runs a command over SSH with context for cancellation +func (c *Client) Execute(ctx context.Context, command string) (string, error) { + session, err := c.client.NewSession() + if err != nil { + return "", fmt.Errorf("failed to create session: %w", err) + } + defer func() { + if err := session.Close(); err != nil && !isClosedError(err) { + fmt.Printf("error closing session: %v\n", err) } }() - // Set up pipes for stdout and stderr - var stdout, stderr io.Reader - stdout, err = session.StdoutPipe() + // Set up pipes for output + stdout, err := session.StdoutPipe() if err != nil { return "", fmt.Errorf("failed to create stdout pipe: %w", err) } - stderr, err = session.StderrPipe() + stderr, err := session.StderrPipe() if err != nil { return "", fmt.Errorf("failed to create stderr pipe: %w", err) } - // Start the command - if err := session.Start(command); err != nil { - return "", fmt.Errorf("failed to start command: %w", err) - } + type commandResult struct { + output string + err error + stderrData string + } + resultChan := make(chan commandResult, 1) - // Read output - outputBytes, err := io.ReadAll(stdout) - if err != nil { - return "", fmt.Errorf("failed to read stdout: %w", err) - } + go func() { + // Start the command + if err := session.Start(command); err != nil { + resultChan <- commandResult{err: fmt.Errorf("failed to start command: %w", err)} + return + } - // Check for errors - errOutput, err := io.ReadAll(stderr) - if err != nil { - return "", fmt.Errorf("failed to read stderr: %w", err) - } + // Read stdout and stderr concurrently + var stdoutData, stderrData []byte + var stdoutErr, stderrErr error + var wg sync.WaitGroup + + wg.Add(2) + go func() { + defer wg.Done() + stdoutData, stdoutErr = io.ReadAll(stdout) + }() + + go func() { + defer wg.Done() + stderrData, stderrErr = io.ReadAll(stderr) + }() + + // Wait for all readers to complete + wg.Wait() + + // Handle any read errors + if stdoutErr != nil { + resultChan <- commandResult{err: fmt.Errorf("failed to read stdout: %w", stdoutErr)} + return + } + if stderrErr != nil { + resultChan <- commandResult{err: fmt.Errorf("failed to read stderr: %w", stderrErr)} + return + } - // Wait for the command to complete - if err := session.Wait(); err != nil { - return "", fmt.Errorf("command failed: %s: %w", string(errOutput), err) - } + // Wait for the command to complete + err := session.Wait() + if err != nil { + resultChan <- commandResult{ + err: fmt.Errorf("command failed: %w", err), + stderrData: string(stderrData), + } + return + } - return string(outputBytes), nil -} + resultChan <- commandResult{ + output: string(stdoutData), + stderrData: string(stderrData), + } + }() -// ValidateConnection tests the SSH connection without executing a command -func (c *Client) ValidateConnection() (err error) { - session, err := c.client.NewSession() - if err != nil { - return fmt.Errorf("failed to create session: %w", err) - } + // Wait for completion or cancellation + select { + case result := <-resultChan: + if result.err != nil { + if result.stderrData != "" { + return "", fmt.Errorf("%w: %s", result.err, result.stderrData) + } + return "", result.err + } + return result.output, nil + + case <-ctx.Done(): + if err := session.Signal(ssh.SIGTERM); err != nil && !isClosedError(err) { + fmt.Printf("error sending SIGTERM: %v\n", err) + } + return "", ctx.Err() - defer func() { - closeErr := session.Close() - if err == nil { - // If validation succeeded, return any close error - err = closeErr - } else if closeErr != nil { - // If both validation and close failed, combine the errors - err = fmt.Errorf("validation error: %w; close error: %v", err, closeErr) + case <-time.After(c.config.Timeout): + if err := session.Signal(ssh.SIGTERM); err != nil && !isClosedError(err) { + fmt.Printf("error sending SIGTERM: %v\n", err) } - }() + return "", context.DeadlineExceeded + } +} - return nil +func isClosedError(err error) bool { + if err == nil { + return false + } + return strings.Contains(err.Error(), "use of closed network connection") || + strings.Contains(err.Error(), "connection reset by peer") || + strings.Contains(err.Error(), "closed network connection") || + strings.Contains(err.Error(), "EOF") } @@ -373,23 +473,27 @@ import ( "context" - - "github.com/rebelopsio/duet/internal/config/executor" + "fmt" ) +// ExecutorInterface defines the required methods for command execution +type ExecutorInterface interface { + Execute(ctx context.Context, command string) (string, error) +} + type PackageManager struct { - executor *executor.Executor + executor ExecutorInterface } -func NewPackageManager(executor *executor.Executor) *PackageManager { +func NewPackageManager(executor ExecutorInterface) *PackageManager { return &PackageManager{ executor: executor, } } -func (pm *PackageManager) Install(ctx context.Context, packageName string) error { - // Implementation - return nil +func (pm *PackageManager) Install(ctx context.Context, packageName string) error { + _, err := pm.executor.Execute(ctx, fmt.Sprintf("apt-get install -y %s", packageName)) + return err } @@ -405,14 +509,14 @@ state *lua.LState } -func NewEngine() *Engine { +func NewEngine() *Engine { return &Engine{ state: lua.NewState(), } } -func (e *Engine) Close() { - if e.state != nil { +func (e *Engine) Close() { + if e.state != nil { e.state.Close() } } @@ -421,13 +525,13 @@ return e.state.DoFile(filename) } -func (e *Engine) CallFunction(name string, args ...lua.LValue) (lua.LValue, error) { +func (e *Engine) CallFunction(name string, args ...lua.LValue) (lua.LValue, error) { fn := e.state.GetGlobal(name) - if fn == lua.LNil { + if fn == lua.LNil { return nil, fmt.Errorf("function %s not found", name) } - err := e.state.CallByParam(lua.P{ + err := e.state.CallByParam(lua.P{ Fn: fn, NRet: 1, Protect: true, @@ -436,7 +540,7 @@ return nil, fmt.Errorf("error calling function %s: %w", name, err) } - ret := e.state.Get(-1) + ret := e.state.Get(-1) e.state.Pop(1) return ret, nil } @@ -468,47 +572,47 @@ ConfigApplied bool } -func NewStore(dbPath string) (*Store, error) { +func NewStore(dbPath string) (*Store, error) { db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{}) if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } - if err := db.AutoMigrate(&Resource{}); err != nil { + if err := db.AutoMigrate(&Resource{}); err != nil { return nil, fmt.Errorf("failed to migrate database: %w", err) } - return &Store{db: db}, nil + return &Store{db: db}, nil } // GetResources retrieves all resources from the store -func (s *Store) GetResources(ctx context.Context) ([]Resource, error) { +func (s *Store) GetResources(ctx context.Context) ([]Resource, error) { var resources []Resource result := s.db.WithContext(ctx).Find(&resources) if result.Error != nil { return nil, fmt.Errorf("failed to get resources: %w", result.Error) } - return resources, nil + return resources, nil } // SaveResource saves a resource to the store -func (s *Store) SaveResource(ctx context.Context, resource *Resource) error { +func (s *Store) SaveResource(ctx context.Context, resource *Resource) error { result := s.db.WithContext(ctx).Save(resource) return result.Error } // GetResource retrieves a single resource by ID -func (s *Store) GetResource(ctx context.Context, id string) (*Resource, error) { +func (s *Store) GetResource(ctx context.Context, id string) (*Resource, error) { var resource Resource result := s.db.WithContext(ctx).First(&resource, "id = ?", id) - if result.Error != nil { + if result.Error != nil { return nil, result.Error } - return &resource, nil + return &resource, nil } // DeleteResource removes a resource from the store -func (s *Store) DeleteResource(ctx context.Context, id string) error { +func (s *Store) DeleteResource(ctx context.Context, id string) error { result := s.db.WithContext(ctx).Delete(&Resource{}, "id = ?", id) return result.Error } @@ -519,12 +623,12 @@ import ( "context" - "github.com/rebelopsio/duet/internal/core/state" "github.com/rebelopsio/duet/internal/iac/provider" + "github.com/rebelopsio/duet/pkg/types" ) type Change struct { - Resource provider.Resource + Resource types.Resource Config map[string]interface{} Type string Provider string @@ -535,23 +639,21 @@ } type Planner struct { - store *state.Store providers map[string]provider.Provider } -func NewPlanner(store *state.Store) *Planner { +func NewPlanner() *Planner { return &Planner{ - store: store, providers: make(map[string]provider.Provider), } } -func (p *Planner) RegisterProvider(provider provider.Provider) { +func (p *Planner) RegisterProvider(provider provider.Provider) { p.providers[provider.Name()] = provider } -func (p *Planner) CreatePlan(ctx context.Context, config map[string]interface{}) (*Plan, error) { - // Implementation +func (p *Planner) CreatePlan(ctx context.Context, config map[string]interface{}) (*Plan, error) { + // Implementation will go here return &Plan{}, nil } diff --git a/internal/config/executor/executor.go b/internal/config/executor/executor.go index 75449db..0f4e04d 100644 --- a/internal/config/executor/executor.go +++ b/internal/config/executor/executor.go @@ -2,26 +2,23 @@ package executor import ( "context" - "fmt" - - "github.com/rebelopsio/duet/internal/config/ssh" ) -type Executor struct { - sshClient *ssh.Client +// Executor defines the interface for executing commands +type ExecutorInterface interface { + Execute(ctx context.Context, command string) (string, error) } -func NewExecutor(sshConfig *ssh.Config) (*Executor, error) { - client, err := ssh.NewClient(sshConfig) - if err != nil { - return nil, fmt.Errorf("failed to create SSH client: %w", err) - } +type Executor struct { + executor ExecutorInterface +} +func NewExecutor(executor ExecutorInterface) *Executor { return &Executor{ - sshClient: client, - }, nil + executor: executor, + } } func (e *Executor) Execute(ctx context.Context, command string) (string, error) { - return e.sshClient.Execute(ctx, command) + return e.executor.Execute(ctx, command) } diff --git a/internal/config/executor/executor_test.go b/internal/config/executor/executor_test.go new file mode 100644 index 0000000..d649899 --- /dev/null +++ b/internal/config/executor/executor_test.go @@ -0,0 +1,38 @@ +package executor + +import ( + "context" + "testing" +) + +// mockExecutor implements the necessary methods for testing +type mockExecutor struct { + executeFunc func(ctx context.Context, command string) (string, error) +} + +func (m *mockExecutor) Execute(ctx context.Context, command string) (string, error) { + if m.executeFunc != nil { + return m.executeFunc(ctx, command) + } + return "", nil +} + +func TestExecutor(t *testing.T) { + t.Run("ExecuteCommand", func(t *testing.T) { + expectedOutput := "command output" + executor := &mockExecutor{ + executeFunc: func(ctx context.Context, command string) (string, error) { + return expectedOutput, nil + }, + } + + output, err := executor.Execute(context.Background(), "test command") + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + if output != expectedOutput { + t.Errorf("Expected %q, got %q", expectedOutput, output) + } + }) +} diff --git a/internal/config/ssh/client.go b/internal/config/ssh/client.go index 869360c..9e9668b 100644 --- a/internal/config/ssh/client.go +++ b/internal/config/ssh/client.go @@ -4,48 +4,95 @@ import ( "context" "fmt" "io" + "net" + "strings" + "sync" + "time" "golang.org/x/crypto/ssh" ) +// Config holds the SSH client configuration type Config struct { Host string User string PrivateKey string Port int + Timeout time.Duration } +// Client represents an SSH client type Client struct { config *Config client *ssh.Client } +// NewClient creates a new SSH client with timeouts func NewClient(config *Config) (*Client, error) { + if config.Timeout == 0 { + config.Timeout = 30 * time.Second + } + + // Parse the private key signer, err := ssh.ParsePrivateKey([]byte(config.PrivateKey)) if err != nil { return nil, fmt.Errorf("failed to parse private key: %w", err) } + // Create SSH client config sshConfig := &ssh.ClientConfig{ User: config.User, Auth: []ssh.AuthMethod{ ssh.PublicKeys(signer), }, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), // Note: In production, use proper host key verification + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: config.Timeout, } + // Create a connection with timeout addr := fmt.Sprintf("%s:%d", config.Host, config.Port) - client, err := ssh.Dial("tcp", addr, sshConfig) + conn, err := net.DialTimeout("tcp", addr, config.Timeout) if err != nil { - return nil, fmt.Errorf("failed to dial SSH: %w", err) + return nil, fmt.Errorf("failed to connect: %w", err) } + // Set connection deadline + if err := conn.SetDeadline(time.Now().Add(config.Timeout)); err != nil { + closeErr := conn.Close() + if closeErr != nil { + return nil, fmt.Errorf("failed to set connection deadline and close connection: %v, close error: %w", err, closeErr) + } + return nil, fmt.Errorf("failed to set connection deadline: %w", err) + } + + // Create new SSH client connection + c, chans, reqs, err := ssh.NewClientConn(conn.(*net.TCPConn), addr, sshConfig) + if err != nil { + closeErr := conn.Close() + if closeErr != nil { + return nil, fmt.Errorf("failed to create SSH connection and close connection: %v, close error: %w", err, closeErr) + } + return nil, fmt.Errorf("failed to create SSH connection: %w", err) + } + + // Clear the deadline after successful handshake + if err := conn.SetDeadline(time.Time{}); err != nil { + closeErr := c.Close() + if closeErr != nil { + return nil, fmt.Errorf("failed to clear connection deadline and close client: %v, close error: %w", err, closeErr) + } + return nil, fmt.Errorf("failed to clear connection deadline: %w", err) + } + + client := ssh.NewClient(c, chans, reqs) + return &Client{ config: config, client: client, }, nil } +// Close closes the SSH connection func (c *Client) Close() error { if c.client != nil { return c.client.Close() @@ -53,76 +100,132 @@ func (c *Client) Close() error { return nil } -func (c *Client) Execute(ctx context.Context, command string) (output string, err error) { +// ValidateConnection tests if the SSH connection is working +func (c *Client) ValidateConnection() error { session, err := c.client.NewSession() if err != nil { - return "", fmt.Errorf("failed to create session: %w", err) + return fmt.Errorf("failed to create session: %w", err) } + defer func() { + if err := session.Close(); err != nil { + fmt.Printf("error closing session: %v\n", err) + } + }() + return nil +} +// Execute runs a command over SSH with context for cancellation +func (c *Client) Execute(ctx context.Context, command string) (string, error) { + session, err := c.client.NewSession() + if err != nil { + return "", fmt.Errorf("failed to create session: %w", err) + } defer func() { - closeErr := session.Close() - if err == nil { - // If there was no error from the command, return any close error - err = closeErr - } else if closeErr != nil { - // If there were both command and close errors, combine them - err = fmt.Errorf("command error: %w; close error: %v", err, closeErr) + if err := session.Close(); err != nil && !isClosedError(err) { + fmt.Printf("error closing session: %v\n", err) } }() - // Set up pipes for stdout and stderr - var stdout, stderr io.Reader - stdout, err = session.StdoutPipe() + // Set up pipes for output + stdout, err := session.StdoutPipe() if err != nil { return "", fmt.Errorf("failed to create stdout pipe: %w", err) } - stderr, err = session.StderrPipe() + stderr, err := session.StderrPipe() if err != nil { return "", fmt.Errorf("failed to create stderr pipe: %w", err) } - // Start the command - if err := session.Start(command); err != nil { - return "", fmt.Errorf("failed to start command: %w", err) + type commandResult struct { + output string + err error + stderrData string } + resultChan := make(chan commandResult, 1) - // Read output - outputBytes, err := io.ReadAll(stdout) - if err != nil { - return "", fmt.Errorf("failed to read stdout: %w", err) - } + go func() { + // Start the command + if err := session.Start(command); err != nil { + resultChan <- commandResult{err: fmt.Errorf("failed to start command: %w", err)} + return + } - // Check for errors - errOutput, err := io.ReadAll(stderr) - if err != nil { - return "", fmt.Errorf("failed to read stderr: %w", err) - } + // Read stdout and stderr concurrently + var stdoutData, stderrData []byte + var stdoutErr, stderrErr error + var wg sync.WaitGroup + + wg.Add(2) + go func() { + defer wg.Done() + stdoutData, stdoutErr = io.ReadAll(stdout) + }() + + go func() { + defer wg.Done() + stderrData, stderrErr = io.ReadAll(stderr) + }() + + // Wait for all readers to complete + wg.Wait() + + // Handle any read errors + if stdoutErr != nil { + resultChan <- commandResult{err: fmt.Errorf("failed to read stdout: %w", stdoutErr)} + return + } + if stderrErr != nil { + resultChan <- commandResult{err: fmt.Errorf("failed to read stderr: %w", stderrErr)} + return + } - // Wait for the command to complete - if err := session.Wait(); err != nil { - return "", fmt.Errorf("command failed: %s: %w", string(errOutput), err) - } + // Wait for the command to complete + err := session.Wait() + if err != nil { + resultChan <- commandResult{ + err: fmt.Errorf("command failed: %w", err), + stderrData: string(stderrData), + } + return + } - return string(outputBytes), nil -} + resultChan <- commandResult{ + output: string(stdoutData), + stderrData: string(stderrData), + } + }() -// ValidateConnection tests the SSH connection without executing a command -func (c *Client) ValidateConnection() (err error) { - session, err := c.client.NewSession() - if err != nil { - return fmt.Errorf("failed to create session: %w", err) - } + // Wait for completion or cancellation + select { + case result := <-resultChan: + if result.err != nil { + if result.stderrData != "" { + return "", fmt.Errorf("%w: %s", result.err, result.stderrData) + } + return "", result.err + } + return result.output, nil - defer func() { - closeErr := session.Close() - if err == nil { - // If validation succeeded, return any close error - err = closeErr - } else if closeErr != nil { - // If both validation and close failed, combine the errors - err = fmt.Errorf("validation error: %w; close error: %v", err, closeErr) + case <-ctx.Done(): + if err := session.Signal(ssh.SIGTERM); err != nil && !isClosedError(err) { + fmt.Printf("error sending SIGTERM: %v\n", err) } - }() + return "", ctx.Err() - return nil + case <-time.After(c.config.Timeout): + if err := session.Signal(ssh.SIGTERM); err != nil && !isClosedError(err) { + fmt.Printf("error sending SIGTERM: %v\n", err) + } + return "", context.DeadlineExceeded + } +} + +func isClosedError(err error) bool { + if err == nil { + return false + } + return strings.Contains(err.Error(), "use of closed network connection") || + strings.Contains(err.Error(), "connection reset by peer") || + strings.Contains(err.Error(), "closed network connection") || + strings.Contains(err.Error(), "EOF") } diff --git a/internal/config/ssh/client_test.go b/internal/config/ssh/client_test.go index b884923..9d16933 100644 --- a/internal/config/ssh/client_test.go +++ b/internal/config/ssh/client_test.go @@ -1,165 +1,492 @@ package ssh import ( + "bytes" "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" "fmt" "io" "net" + "strings" + "sync" "testing" + "time" "golang.org/x/crypto/ssh" ) -// mockSSHServer simulates an SSH server for testing +type keyPair struct { + PublicKey ssh.PublicKey + PrivateKey string +} + type mockSSHServer struct { - listener net.Listener - config *ssh.ServerConfig + listener net.Listener + ctx context.Context + config *ssh.ServerConfig + ready chan struct{} + done chan struct{} + activeConns map[string]net.Conn + t *testing.T + cancel context.CancelFunc + wg sync.WaitGroup + activesMutex sync.RWMutex +} + +// generateTestKey generates a test RSA key pair for testing +func generateTestKey() (*keyPair, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, fmt.Errorf("failed to generate private key: %w", err) + } + + // Convert private key to PEM format + privateKeyPEM := &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + } + + // Generate SSH public key + publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey) + if err != nil { + return nil, fmt.Errorf("failed to create public key: %w", err) + } + + return &keyPair{ + PrivateKey: string(pem.EncodeToMemory(privateKeyPEM)), + PublicKey: publicKey, + }, nil } -func newMockSSHServer(t *testing.T) (*mockSSHServer, error) { +func newMockSSHServer(t *testing.T, keys *keyPair) (*mockSSHServer, error) { + ctx, cancel := context.WithCancel(context.Background()) + config := &ssh.ServerConfig{ - PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { - return nil, nil + PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + if key.Type() == keys.PublicKey.Type() && bytes.Equal(key.Marshal(), keys.PublicKey.Marshal()) { + return &ssh.Permissions{}, nil + } + return nil, fmt.Errorf("unknown public key") }, } - privateKey, err := ssh.ParsePrivateKey([]byte(testPrivateKey)) + signer, err := ssh.ParsePrivateKey([]byte(keys.PrivateKey)) if err != nil { + cancel() return nil, fmt.Errorf("failed to parse private key: %w", err) } - config.AddHostKey(privateKey) + config.AddHostKey(signer) listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { + cancel() return nil, fmt.Errorf("failed to listen: %w", err) } server := &mockSSHServer{ - listener: listener, - config: config, + listener: listener, + config: config, + ready: make(chan struct{}), + done: make(chan struct{}), + activeConns: make(map[string]net.Conn), + t: t, + ctx: ctx, + cancel: cancel, } - go server.serve(t) - return server, nil + server.wg.Add(1) + go server.acceptConnections() + + // Wait for server to be ready + select { + case <-server.ready: + return server, nil + case <-time.After(2 * time.Second): + cancel() + if err := listener.Close(); err != nil { + t.Logf("Failed to close listener: %v", err) + } + return nil, fmt.Errorf("timeout waiting for server to be ready") + } } -func (s *mockSSHServer) serve(t *testing.T) { +func (s *mockSSHServer) acceptConnections() { + defer s.wg.Done() + defer close(s.done) + defer s.cancel() + + // Signal that we're ready to accept connections + close(s.ready) + for { + if err := s.listener.(*net.TCPListener).SetDeadline(time.Now().Add(time.Second)); err != nil { + s.t.Logf("Failed to set accept deadline: %v", err) + return + } + conn, err := s.listener.Accept() if err != nil { + if isTimeout(err) { + select { + case <-s.ctx.Done(): + return + default: + continue + } + } if !isClosedError(err) { - t.Errorf("Failed to accept connection: %v", err) + s.t.Logf("Accept error: %v", err) } return } - _, chans, reqs, err := ssh.NewServerConn(conn, s.config) - if err != nil { - t.Errorf("Failed to handshake: %v", err) - continue - } + s.activesMutex.Lock() + s.activeConns[conn.RemoteAddr().String()] = conn + s.activesMutex.Unlock() - go ssh.DiscardRequests(reqs) - go handleChannels(chans, t) + s.wg.Add(1) + go s.handleConnection(conn) } } -func handleChannels(chans <-chan ssh.NewChannel, t *testing.T) { - for newChannel := range chans { - if newChannel.ChannelType() != "session" { - newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") - continue +func isTimeout(err error) bool { + if err == nil { + return false + } + return strings.Contains(err.Error(), "timeout") || + strings.Contains(err.Error(), "deadline exceeded") || + err == context.DeadlineExceeded +} + +func (s *mockSSHServer) handleConnection(conn net.Conn) { + defer s.wg.Done() + defer func() { + s.activesMutex.Lock() + delete(s.activeConns, conn.RemoteAddr().String()) + s.activesMutex.Unlock() + if err := conn.Close(); err != nil && !isClosedError(err) { + s.t.Logf("Connection close error: %v", err) } + }() - channel, requests, err := newChannel.Accept() - if err != nil { - t.Errorf("Failed to accept channel: %v", err) - continue + sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.config) + if err != nil { + if !isClosedError(err) { + s.t.Logf("SSH handshake error: %v", err) } + return + } + + go func() { + <-s.ctx.Done() + if err := sshConn.Close(); err != nil && !isClosedError(err) { + s.t.Logf("SSH connection close error: %v", err) + } + }() + + go ssh.DiscardRequests(reqs) + s.handleChannels(chans) +} + +func (s *mockSSHServer) handleRequests(channel ssh.Channel, requests <-chan *ssh.Request) { + defer func() { + if err := channel.Close(); err != nil && !isClosedError(err) { + s.t.Logf("Channel close error: %v", err) + } + }() - go func(in <-chan *ssh.Request) { - for req := range in { - switch req.Type { - case "exec": - payload := struct{ Command string }{} - ssh.Unmarshal(req.Payload, &payload) + for { + select { + case <-s.ctx.Done(): + return + case req, ok := <-requests: + if !ok { + return + } + switch req.Type { + case "exec": + exitStatus := make([]byte, 4) - if payload.Command == "echo test" { - io.WriteString(channel, "test\n") + payload := struct{ Command string }{} + if err := ssh.Unmarshal(req.Payload, &payload); err != nil { + s.t.Logf("Unmarshal error: %v", err) + if err := req.Reply(false, nil); err != nil { + s.t.Logf("Reply error: %v", err) } + continue + } - channel.Close() + if err := req.Reply(true, nil); err != nil { + s.t.Logf("Reply error: %v", err) + continue } - req.Reply(true, nil) + + if payload.Command == "echo test" { + if _, err := io.WriteString(channel, "test\n"); err != nil { + s.t.Logf("Write error: %v", err) + } + + // Send exit status + _, err := channel.SendRequest("exit-status", false, exitStatus) + if err != nil { + s.t.Logf("Failed to send exit status: %v", err) + } + + if err := channel.CloseWrite(); err != nil && !isClosedError(err) { + s.t.Logf("CloseWrite error: %v", err) + } + return + } + + if payload.Command == "sleep 10" { + // For the timeout test, we'll block until context is cancelled + select { + case <-s.ctx.Done(): + // Send non-zero exit status for cancelled command + exitStatus[3] = 1 + _, err := channel.SendRequest("exit-status", false, exitStatus) + if err != nil { + s.t.Logf("Failed to send exit status: %v", err) + } + case <-time.After(10 * time.Second): + // Normal completion + _, err := channel.SendRequest("exit-status", false, exitStatus) + if err != nil { + s.t.Logf("Failed to send exit status: %v", err) + } + } + + if err := channel.CloseWrite(); err != nil && !isClosedError(err) { + s.t.Logf("CloseWrite error: %v", err) + } + return + } + + // Unknown command, send error exit status + exitStatus[3] = 1 + _, err := channel.SendRequest("exit-status", false, exitStatus) + if err != nil { + s.t.Logf("Failed to send exit status: %v", err) + } + + if err := channel.CloseWrite(); err != nil && !isClosedError(err) { + s.t.Logf("CloseWrite error: %v", err) + } + return + + default: + if err := req.Reply(false, nil); err != nil { + s.t.Logf("Reply error: %v", err) + } + } + } + } +} + +func (s *mockSSHServer) handleChannels(chans <-chan ssh.NewChannel) { + for { + select { + case <-s.ctx.Done(): + return + case newChannel, ok := <-chans: + if !ok { + return } - }(requests) + go s.handleChannel(newChannel) + } } } -func isClosedError(err error) bool { - return err.Error() == "use of closed network connection" +func (s *mockSSHServer) handleChannel(newChannel ssh.NewChannel) { + if newChannel.ChannelType() != "session" { + if err := newChannel.Reject(ssh.UnknownChannelType, "unknown channel type"); err != nil { + s.t.Logf("Channel reject error: %v", err) + } + return + } + + channel, requests, err := newChannel.Accept() + if err != nil { + s.t.Logf("Channel accept error: %v", err) + return + } + + go s.handleRequests(channel, requests) } -const testPrivateKey = `-----BEGIN OPENSSH PRIVATE KEY----- -b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABFwAAAAdzc2gtcn -NhAAAAAwEAAQAAAQEAxU4rixQXoahCL2gVoNWswNMFxYEiO0YH9YbB1qh+9nYRYGzEOc0l -... ------END OPENSSH PRIVATE KEY-----` +func (s *mockSSHServer) shutdown() error { + s.cancel() + + if err := s.listener.Close(); err != nil { + return fmt.Errorf("failed to close listener: %w", err) + } + + done := make(chan struct{}) + go func() { + s.wg.Wait() + close(done) + }() + + select { + case <-done: + return nil + case <-time.After(5 * time.Second): + return fmt.Errorf("timeout waiting for server shutdown") + } +} func TestClient(t *testing.T) { - mockServer, err := newMockSSHServer(t) + if testing.Short() { + t.Skip("Skipping SSH tests in short mode") + } + + // Generate a test key for this test run + keys, err := generateTestKey() + if err != nil { + t.Fatalf("Failed to generate test key: %v", err) + } + + mockServer, err := newMockSSHServer(t, keys) if err != nil { t.Fatalf("Failed to start mock SSH server: %v", err) } - defer mockServer.listener.Close() + + // Use a cleanup function to ensure proper shutdown + cleanup := func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + done := make(chan error, 1) + go func() { + done <- mockServer.shutdown() + }() + + select { + case err := <-done: + if err != nil { + t.Logf("Server shutdown error: %v", err) + } + case <-ctx.Done(): + t.Log("Server shutdown timed out") + } + } + defer cleanup() serverAddr := mockServer.listener.Addr().String() - host, port, err := net.SplitHostPort(serverAddr) + host, portStr, err := net.SplitHostPort(serverAddr) if err != nil { t.Fatalf("Failed to parse server address: %v", err) } + var port int + if _, err := fmt.Sscanf(portStr, "%d", &port); err != nil { + t.Fatalf("Failed to parse port number: %v", err) + } + config := &Config{ Host: host, + Port: port, User: "test", - PrivateKey: testPrivateKey, - Port: parseInt(port), + PrivateKey: keys.PrivateKey, + Timeout: 2 * time.Second, } - t.Run("Connect", func(t *testing.T) { + // Helper function to create and cleanup client + createClient := func(t *testing.T) (*Client, func()) { client, err := NewClient(config) if err != nil { t.Fatalf("Failed to create client: %v", err) } - defer client.Close() - err = client.ValidateConnection() - if err != nil { - t.Errorf("Failed to validate connection: %v", err) + cleanup := func() { + if err := client.Close(); err != nil && !isClosedError(err) { + t.Logf("Client close error: %v", err) + } + } + + return client, cleanup + } + + t.Run("Connect", func(t *testing.T) { + client, cleanup := createClient(t) + defer cleanup() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + done := make(chan error, 1) + go func() { + done <- client.ValidateConnection() + }() + + select { + case err := <-done: + if err != nil { + t.Errorf("Failed to validate connection: %v", err) + } + case <-ctx.Done(): + t.Error("Connection validation timed out") } }) t.Run("Execute", func(t *testing.T) { - client, err := NewClient(config) - if err != nil { - t.Fatalf("Failed to create client: %v", err) - } - defer client.Close() + client, cleanup := createClient(t) + defer cleanup() - output, err := client.Execute(context.Background(), "echo test") - if err != nil { - t.Fatalf("Failed to execute command: %v", err) - } + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + done := make(chan struct { + output string + err error + }, 1) - expected := "test\n" - if output != expected { - t.Errorf("Expected output %q, got %q", expected, output) + go func() { + output, err := client.Execute(ctx, "echo test") + done <- struct { + output string + err error + }{output, err} + }() + + select { + case result := <-done: + if result.err != nil { + t.Fatalf("Failed to execute command: %v", result.err) + } + if result.output != "test\n" { + t.Errorf("Expected output %q, got %q", "test\n", result.output) + } + case <-ctx.Done(): + t.Fatal("Command execution timed out") } }) -} -func parseInt(s string) int { - var port int - fmt.Sscanf(s, "%d", &port) - return port + t.Run("ExecuteWithTimeout", func(t *testing.T) { + client, cleanup := createClient(t) + defer cleanup() + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + done := make(chan error, 1) + go func() { + _, err := client.Execute(ctx, "sleep 10") + done <- err + }() + + select { + case err := <-done: + if err == nil { + t.Error("Expected error, got nil") + } else if err != context.DeadlineExceeded { + t.Errorf("Expected context.DeadlineExceeded error, got: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("Test timed out waiting for command timeout") + } + }) } diff --git a/internal/config/tasks/package.go b/internal/config/tasks/package.go index 4a5ee04..bec058a 100644 --- a/internal/config/tasks/package.go +++ b/internal/config/tasks/package.go @@ -2,21 +2,25 @@ package tasks import ( "context" - - "github.com/rebelopsio/duet/internal/config/executor" + "fmt" ) +// ExecutorInterface defines the required methods for command execution +type ExecutorInterface interface { + Execute(ctx context.Context, command string) (string, error) +} + type PackageManager struct { - executor *executor.Executor + executor ExecutorInterface } -func NewPackageManager(executor *executor.Executor) *PackageManager { +func NewPackageManager(executor ExecutorInterface) *PackageManager { return &PackageManager{ executor: executor, } } func (pm *PackageManager) Install(ctx context.Context, packageName string) error { - // Implementation - return nil + _, err := pm.executor.Execute(ctx, fmt.Sprintf("apt-get install -y %s", packageName)) + return err } diff --git a/internal/config/tasks/package_test.go b/internal/config/tasks/package_test.go new file mode 100644 index 0000000..0566c60 --- /dev/null +++ b/internal/config/tasks/package_test.go @@ -0,0 +1,33 @@ +package tasks + +import ( + "context" + "testing" +) + +type mockExecutor struct { + executeFunc func(ctx context.Context, command string) (string, error) +} + +func (m *mockExecutor) Execute(ctx context.Context, command string) (string, error) { + return m.executeFunc(ctx, command) +} + +func TestPackageManager(t *testing.T) { + t.Run("InstallPackage", func(t *testing.T) { + executorMock := &mockExecutor{ + executeFunc: func(ctx context.Context, command string) (string, error) { + return "Package installed successfully", nil + }, + } + + pm := &PackageManager{ + executor: executorMock, + } + + err := pm.Install(context.Background(), "cowsay") + if err != nil { + t.Fatalf("Failed to install package: %v", err) + } + }) +} diff --git a/internal/core/state/store_test.go b/internal/core/state/store_test.go index 6692861..abbff2e 100644 --- a/internal/core/state/store_test.go +++ b/internal/core/state/store_test.go @@ -2,6 +2,7 @@ package state import ( "context" + "log" "os" "path/filepath" "testing" @@ -13,7 +14,11 @@ func TestStore(t *testing.T) { if err != nil { t.Fatalf("Failed to create temp dir: %v", err) } - defer os.RemoveAll(tmpDir) + defer func() { + if err := os.RemoveAll(tmpDir); err != nil { + log.Printf("Warning: failed to remove temp directory: %v", err) + } + }() dbPath := filepath.Join(tmpDir, "test.db") diff --git a/internal/iac/planner/planner.go b/internal/iac/planner/planner.go index 2db2167..de1ef4e 100644 --- a/internal/iac/planner/planner.go +++ b/internal/iac/planner/planner.go @@ -3,12 +3,12 @@ package planner import ( "context" - "github.com/rebelopsio/duet/internal/core/state" "github.com/rebelopsio/duet/internal/iac/provider" + "github.com/rebelopsio/duet/pkg/types" ) type Change struct { - Resource provider.Resource + Resource types.Resource Config map[string]interface{} Type string Provider string @@ -19,13 +19,11 @@ type Plan struct { } type Planner struct { - store *state.Store providers map[string]provider.Provider } -func NewPlanner(store *state.Store) *Planner { +func NewPlanner() *Planner { return &Planner{ - store: store, providers: make(map[string]provider.Provider), } } @@ -35,6 +33,6 @@ func (p *Planner) RegisterProvider(provider provider.Provider) { } func (p *Planner) CreatePlan(ctx context.Context, config map[string]interface{}) (*Plan, error) { - // Implementation + // Implementation will go here return &Plan{}, nil } diff --git a/internal/iac/planner/planner_test.go b/internal/iac/planner/planner_test.go new file mode 100644 index 0000000..fd01850 --- /dev/null +++ b/internal/iac/planner/planner_test.go @@ -0,0 +1,93 @@ +package planner + +import ( + "context" + "testing" + "time" + + "github.com/rebelopsio/duet/internal/iac/provider" + "github.com/rebelopsio/duet/pkg/types" +) + +// mockResource implements the types.Resource interface +type mockResource struct { + created time.Time + updated time.Time + metadata map[string]interface{} + tags map[string]string + id string + resType types.ResourceType + provider string + status types.ResourceStatus +} + +func (m *mockResource) GetID() string { return m.id } +func (m *mockResource) GetType() types.ResourceType { return m.resType } +func (m *mockResource) GetProvider() string { return m.provider } +func (m *mockResource) GetStatus() types.ResourceStatus { return m.status } +func (m *mockResource) GetMetadata() map[string]interface{} { return m.metadata } +func (m *mockResource) GetTags() map[string]string { return m.tags } +func (m *mockResource) GetCreatedAt() time.Time { return m.created } +func (m *mockResource) GetUpdatedAt() time.Time { return m.updated } + +// mockProvider implements the provider.Provider interface +type mockProvider struct { + name string +} + +func (m *mockProvider) Name() string { return m.name } + +func (m *mockProvider) Create(ctx context.Context, resourceType string, config map[string]interface{}) (provider.Resource, error) { + return &mockResource{ + id: "test-id", + resType: types.ResourceType(resourceType), + provider: m.name, + status: types.StatusRunning, + metadata: config, + tags: make(map[string]string), + created: time.Now(), + updated: time.Now(), + }, nil +} + +func (m *mockProvider) Read(ctx context.Context, resourceType string, id string) (provider.Resource, error) { + return nil, nil +} + +func (m *mockProvider) Update(ctx context.Context, resource provider.Resource, config map[string]interface{}) error { + return nil +} + +func (m *mockProvider) Delete(ctx context.Context, resource provider.Resource) error { + return nil +} + +func TestPlanner(t *testing.T) { + t.Run("CreatePlan", func(t *testing.T) { + planner := &Planner{ + providers: make(map[string]provider.Provider), + } + + mockAWSProvider := &mockProvider{name: "aws"} + planner.RegisterProvider(mockAWSProvider) + + config := map[string]interface{}{ + "provider": "aws", + "resources": []map[string]interface{}{ + { + "type": "instance", + "name": "test-instance", + }, + }, + } + + plan, err := planner.CreatePlan(context.Background(), config) + if err != nil { + t.Fatalf("Failed to create plan: %v", err) + } + + if plan == nil { + t.Error("Expected plan to not be nil") + } + }) +} diff --git a/internal/iac/provider/provider.go b/internal/iac/provider/provider.go index 5f3ead0..afd3a25 100644 --- a/internal/iac/provider/provider.go +++ b/internal/iac/provider/provider.go @@ -2,18 +2,27 @@ package provider import ( "context" + + "github.com/rebelopsio/duet/pkg/types" ) -type Resource interface { - ID() string - Type() string - Metadata() map[string]interface{} -} +// Resource is an alias for types.Resource to maintain package consistency +type Resource = types.Resource +// Provider defines the interface for infrastructure providers type Provider interface { + // Name returns the provider's name Name() string + + // Create creates a new resource Create(ctx context.Context, resourceType string, config map[string]interface{}) (Resource, error) + + // Read retrieves an existing resource Read(ctx context.Context, resourceType string, id string) (Resource, error) + + // Update updates an existing resource Update(ctx context.Context, resource Resource, config map[string]interface{}) error + + // Delete removes an existing resource Delete(ctx context.Context, resource Resource) error }