Skip to content

Commit

Permalink
chore: update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rebelopsio committed Nov 22, 2024
1 parent d8add01 commit 800b995
Show file tree
Hide file tree
Showing 11 changed files with 977 additions and 268 deletions.
324 changes: 213 additions & 111 deletions coverage/coverage.html

Large diffs are not rendered by default.

23 changes: 10 additions & 13 deletions internal/config/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,23 @@ package executor

import (
"context"
"fmt"

"github.com/rebelopsio/duet/internal/config/ssh"
)

type Executor struct {
sshClient *ssh.Client
// Executor defines the interface for executing commands
type ExecutorInterface interface {
Execute(ctx context.Context, command string) (string, error)
}

func NewExecutor(sshConfig *ssh.Config) (*Executor, error) {
client, err := ssh.NewClient(sshConfig)
if err != nil {
return nil, fmt.Errorf("failed to create SSH client: %w", err)
}
type Executor struct {
executor ExecutorInterface
}

func NewExecutor(executor ExecutorInterface) *Executor {
return &Executor{
sshClient: client,
}, nil
executor: executor,
}
}

func (e *Executor) Execute(ctx context.Context, command string) (string, error) {
return e.sshClient.Execute(ctx, command)
return e.executor.Execute(ctx, command)
}
38 changes: 38 additions & 0 deletions internal/config/executor/executor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package executor

import (
"context"
"testing"
)

// mockExecutor implements the necessary methods for testing
type mockExecutor struct {
executeFunc func(ctx context.Context, command string) (string, error)
}

func (m *mockExecutor) Execute(ctx context.Context, command string) (string, error) {
if m.executeFunc != nil {
return m.executeFunc(ctx, command)
}
return "", nil
}

func TestExecutor(t *testing.T) {
t.Run("ExecuteCommand", func(t *testing.T) {
expectedOutput := "command output"
executor := &mockExecutor{
executeFunc: func(ctx context.Context, command string) (string, error) {
return expectedOutput, nil
},
}

output, err := executor.Execute(context.Background(), "test command")
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}

if output != expectedOutput {
t.Errorf("Expected %q, got %q", expectedOutput, output)
}
})
}
205 changes: 154 additions & 51 deletions internal/config/ssh/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,125 +4,228 @@ import (
"context"
"fmt"
"io"
"net"
"strings"
"sync"
"time"

"golang.org/x/crypto/ssh"
)

// Config holds the SSH client configuration
type Config struct {
Host string
User string
PrivateKey string
Port int
Timeout time.Duration
}

// Client represents an SSH client
type Client struct {
config *Config
client *ssh.Client
}

// NewClient creates a new SSH client with timeouts
func NewClient(config *Config) (*Client, error) {
if config.Timeout == 0 {
config.Timeout = 30 * time.Second
}

// Parse the private key
signer, err := ssh.ParsePrivateKey([]byte(config.PrivateKey))
if err != nil {
return nil, fmt.Errorf("failed to parse private key: %w", err)
}

// Create SSH client config
sshConfig := &ssh.ClientConfig{
User: config.User,
Auth: []ssh.AuthMethod{
ssh.PublicKeys(signer),
},
HostKeyCallback: ssh.InsecureIgnoreHostKey(), // Note: In production, use proper host key verification
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: config.Timeout,
}

// Create a connection with timeout
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
client, err := ssh.Dial("tcp", addr, sshConfig)
conn, err := net.DialTimeout("tcp", addr, config.Timeout)
if err != nil {
return nil, fmt.Errorf("failed to dial SSH: %w", err)
return nil, fmt.Errorf("failed to connect: %w", err)
}

// Set connection deadline
if err := conn.SetDeadline(time.Now().Add(config.Timeout)); err != nil {
closeErr := conn.Close()
if closeErr != nil {
return nil, fmt.Errorf("failed to set connection deadline and close connection: %v, close error: %w", err, closeErr)
}
return nil, fmt.Errorf("failed to set connection deadline: %w", err)
}

// Create new SSH client connection
c, chans, reqs, err := ssh.NewClientConn(conn.(*net.TCPConn), addr, sshConfig)
if err != nil {
closeErr := conn.Close()
if closeErr != nil {
return nil, fmt.Errorf("failed to create SSH connection and close connection: %v, close error: %w", err, closeErr)
}
return nil, fmt.Errorf("failed to create SSH connection: %w", err)
}

// Clear the deadline after successful handshake
if err := conn.SetDeadline(time.Time{}); err != nil {
closeErr := c.Close()
if closeErr != nil {
return nil, fmt.Errorf("failed to clear connection deadline and close client: %v, close error: %w", err, closeErr)
}
return nil, fmt.Errorf("failed to clear connection deadline: %w", err)
}

client := ssh.NewClient(c, chans, reqs)

return &Client{
config: config,
client: client,
}, nil
}

// Close closes the SSH connection
func (c *Client) Close() error {
if c.client != nil {
return c.client.Close()
}
return nil
}

func (c *Client) Execute(ctx context.Context, command string) (output string, err error) {
// ValidateConnection tests if the SSH connection is working
func (c *Client) ValidateConnection() error {
session, err := c.client.NewSession()
if err != nil {
return "", fmt.Errorf("failed to create session: %w", err)
return fmt.Errorf("failed to create session: %w", err)
}
defer func() {
if err := session.Close(); err != nil {
fmt.Printf("error closing session: %v\n", err)
}
}()
return nil
}

// Execute runs a command over SSH with context for cancellation
func (c *Client) Execute(ctx context.Context, command string) (string, error) {
session, err := c.client.NewSession()
if err != nil {
return "", fmt.Errorf("failed to create session: %w", err)
}
defer func() {
closeErr := session.Close()
if err == nil {
// If there was no error from the command, return any close error
err = closeErr
} else if closeErr != nil {
// If there were both command and close errors, combine them
err = fmt.Errorf("command error: %w; close error: %v", err, closeErr)
if err := session.Close(); err != nil && !isClosedError(err) {
fmt.Printf("error closing session: %v\n", err)
}
}()

// Set up pipes for stdout and stderr
var stdout, stderr io.Reader
stdout, err = session.StdoutPipe()
// Set up pipes for output
stdout, err := session.StdoutPipe()
if err != nil {
return "", fmt.Errorf("failed to create stdout pipe: %w", err)
}
stderr, err = session.StderrPipe()
stderr, err := session.StderrPipe()
if err != nil {
return "", fmt.Errorf("failed to create stderr pipe: %w", err)
}

// Start the command
if err := session.Start(command); err != nil {
return "", fmt.Errorf("failed to start command: %w", err)
type commandResult struct {
output string
err error
stderrData string
}
resultChan := make(chan commandResult, 1)

// Read output
outputBytes, err := io.ReadAll(stdout)
if err != nil {
return "", fmt.Errorf("failed to read stdout: %w", err)
}
go func() {
// Start the command
if err := session.Start(command); err != nil {
resultChan <- commandResult{err: fmt.Errorf("failed to start command: %w", err)}
return
}

// Check for errors
errOutput, err := io.ReadAll(stderr)
if err != nil {
return "", fmt.Errorf("failed to read stderr: %w", err)
}
// Read stdout and stderr concurrently
var stdoutData, stderrData []byte
var stdoutErr, stderrErr error
var wg sync.WaitGroup

wg.Add(2)
go func() {
defer wg.Done()
stdoutData, stdoutErr = io.ReadAll(stdout)
}()

go func() {
defer wg.Done()
stderrData, stderrErr = io.ReadAll(stderr)
}()

// Wait for all readers to complete
wg.Wait()

// Handle any read errors
if stdoutErr != nil {
resultChan <- commandResult{err: fmt.Errorf("failed to read stdout: %w", stdoutErr)}
return
}
if stderrErr != nil {
resultChan <- commandResult{err: fmt.Errorf("failed to read stderr: %w", stderrErr)}
return
}

// Wait for the command to complete
if err := session.Wait(); err != nil {
return "", fmt.Errorf("command failed: %s: %w", string(errOutput), err)
}
// Wait for the command to complete
err := session.Wait()
if err != nil {
resultChan <- commandResult{
err: fmt.Errorf("command failed: %w", err),
stderrData: string(stderrData),
}
return
}

return string(outputBytes), nil
}
resultChan <- commandResult{
output: string(stdoutData),
stderrData: string(stderrData),
}
}()

// ValidateConnection tests the SSH connection without executing a command
func (c *Client) ValidateConnection() (err error) {
session, err := c.client.NewSession()
if err != nil {
return fmt.Errorf("failed to create session: %w", err)
}
// Wait for completion or cancellation
select {
case result := <-resultChan:
if result.err != nil {
if result.stderrData != "" {
return "", fmt.Errorf("%w: %s", result.err, result.stderrData)
}
return "", result.err
}
return result.output, nil

defer func() {
closeErr := session.Close()
if err == nil {
// If validation succeeded, return any close error
err = closeErr
} else if closeErr != nil {
// If both validation and close failed, combine the errors
err = fmt.Errorf("validation error: %w; close error: %v", err, closeErr)
case <-ctx.Done():
if err := session.Signal(ssh.SIGTERM); err != nil && !isClosedError(err) {
fmt.Printf("error sending SIGTERM: %v\n", err)
}
}()
return "", ctx.Err()

return nil
case <-time.After(c.config.Timeout):
if err := session.Signal(ssh.SIGTERM); err != nil && !isClosedError(err) {
fmt.Printf("error sending SIGTERM: %v\n", err)
}
return "", context.DeadlineExceeded
}
}

func isClosedError(err error) bool {
if err == nil {
return false
}
return strings.Contains(err.Error(), "use of closed network connection") ||
strings.Contains(err.Error(), "connection reset by peer") ||
strings.Contains(err.Error(), "closed network connection") ||
strings.Contains(err.Error(), "EOF")
}
Loading

0 comments on commit 800b995

Please sign in to comment.