diff --git a/coverage/coverage.html b/coverage/coverage.html
index 9f8afc4..cdb4acc 100644
--- a/coverage/coverage.html
+++ b/coverage/coverage.html
@@ -59,15 +59,15 @@
-
+
-
+
-
+
-
+
-
+
@@ -214,28 +214,25 @@
import (
"context"
- "fmt"
-
- "github.com/rebelopsio/duet/internal/config/ssh"
)
-type Executor struct {
- sshClient *ssh.Client
+// Executor defines the interface for executing commands
+type ExecutorInterface interface {
+ Execute(ctx context.Context, command string) (string, error)
}
-func NewExecutor(sshConfig *ssh.Config) (*Executor, error) {
- client, err := ssh.NewClient(sshConfig)
- if err != nil {
- return nil, fmt.Errorf("failed to create SSH client: %w", err)
- }
-
- return &Executor{
- sshClient: client,
- }, nil
+type Executor struct {
+ executor ExecutorInterface
}
+func NewExecutor(executor ExecutorInterface) *Executor {
+ return &Executor{
+ executor: executor,
+ }
+}
+
func (e *Executor) Execute(ctx context.Context, command string) (string, error) {
- return e.sshClient.Execute(ctx, command)
+ return e.executor.Execute(ctx, command)
}
@@ -245,127 +242,230 @@
"context"
"fmt"
"io"
+ "net"
+ "strings"
+ "sync"
+ "time"
"golang.org/x/crypto/ssh"
)
+// Config holds the SSH client configuration
type Config struct {
Host string
User string
PrivateKey string
Port int
+ Timeout time.Duration
}
+// Client represents an SSH client
type Client struct {
config *Config
client *ssh.Client
}
-func NewClient(config *Config) (*Client, error) {
- signer, err := ssh.ParsePrivateKey([]byte(config.PrivateKey))
+// NewClient creates a new SSH client with timeouts
+func NewClient(config *Config) (*Client, error) {
+ if config.Timeout == 0 {
+ config.Timeout = 30 * time.Second
+ }
+
+ // Parse the private key
+ signer, err := ssh.ParsePrivateKey([]byte(config.PrivateKey))
if err != nil {
return nil, fmt.Errorf("failed to parse private key: %w", err)
}
- sshConfig := &ssh.ClientConfig{
+ // Create SSH client config
+ sshConfig := &ssh.ClientConfig{
User: config.User,
Auth: []ssh.AuthMethod{
ssh.PublicKeys(signer),
},
- HostKeyCallback: ssh.InsecureIgnoreHostKey(), // Note: In production, use proper host key verification
+ HostKeyCallback: ssh.InsecureIgnoreHostKey(),
+ Timeout: config.Timeout,
}
+ // Create a connection with timeout
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
- client, err := ssh.Dial("tcp", addr, sshConfig)
+ conn, err := net.DialTimeout("tcp", addr, config.Timeout)
if err != nil {
- return nil, fmt.Errorf("failed to dial SSH: %w", err)
+ return nil, fmt.Errorf("failed to connect: %w", err)
}
- return &Client{
+ // Set connection deadline
+ if err := conn.SetDeadline(time.Now().Add(config.Timeout)); err != nil {
+ closeErr := conn.Close()
+ if closeErr != nil {
+ return nil, fmt.Errorf("failed to set connection deadline and close connection: %v, close error: %w", err, closeErr)
+ }
+ return nil, fmt.Errorf("failed to set connection deadline: %w", err)
+ }
+
+ // Create new SSH client connection
+ c, chans, reqs, err := ssh.NewClientConn(conn.(*net.TCPConn), addr, sshConfig)
+ if err != nil {
+ closeErr := conn.Close()
+ if closeErr != nil {
+ return nil, fmt.Errorf("failed to create SSH connection and close connection: %v, close error: %w", err, closeErr)
+ }
+ return nil, fmt.Errorf("failed to create SSH connection: %w", err)
+ }
+
+ // Clear the deadline after successful handshake
+ if err := conn.SetDeadline(time.Time{}); err != nil {
+ closeErr := c.Close()
+ if closeErr != nil {
+ return nil, fmt.Errorf("failed to clear connection deadline and close client: %v, close error: %w", err, closeErr)
+ }
+ return nil, fmt.Errorf("failed to clear connection deadline: %w", err)
+ }
+
+ client := ssh.NewClient(c, chans, reqs)
+
+ return &Client{
config: config,
client: client,
}, nil
}
-func (c *Client) Close() error {
- if c.client != nil {
+// Close closes the SSH connection
+func (c *Client) Close() error {
+ if c.client != nil {
return c.client.Close()
}
return nil
}
-func (c *Client) Execute(ctx context.Context, command string) (output string, err error) {
+// ValidateConnection tests if the SSH connection is working
+func (c *Client) ValidateConnection() error {
session, err := c.client.NewSession()
if err != nil {
- return "", fmt.Errorf("failed to create session: %w", err)
+ return fmt.Errorf("failed to create session: %w", err)
}
+ defer func() {
+ if err := session.Close(); err != nil {
+ fmt.Printf("error closing session: %v\n", err)
+ }
+ }()
+ return nil
+}
- defer func() {
- closeErr := session.Close()
- if err == nil {
- // If there was no error from the command, return any close error
- err = closeErr
- } else if closeErr != nil {
- // If there were both command and close errors, combine them
- err = fmt.Errorf("command error: %w; close error: %v", err, closeErr)
+// Execute runs a command over SSH with context for cancellation
+func (c *Client) Execute(ctx context.Context, command string) (string, error) {
+ session, err := c.client.NewSession()
+ if err != nil {
+ return "", fmt.Errorf("failed to create session: %w", err)
+ }
+ defer func() {
+ if err := session.Close(); err != nil && !isClosedError(err) {
+ fmt.Printf("error closing session: %v\n", err)
}
}()
- // Set up pipes for stdout and stderr
- var stdout, stderr io.Reader
- stdout, err = session.StdoutPipe()
+ // Set up pipes for output
+ stdout, err := session.StdoutPipe()
if err != nil {
return "", fmt.Errorf("failed to create stdout pipe: %w", err)
}
- stderr, err = session.StderrPipe()
+ stderr, err := session.StderrPipe()
if err != nil {
return "", fmt.Errorf("failed to create stderr pipe: %w", err)
}
- // Start the command
- if err := session.Start(command); err != nil {
- return "", fmt.Errorf("failed to start command: %w", err)
- }
+ type commandResult struct {
+ output string
+ err error
+ stderrData string
+ }
+ resultChan := make(chan commandResult, 1)
- // Read output
- outputBytes, err := io.ReadAll(stdout)
- if err != nil {
- return "", fmt.Errorf("failed to read stdout: %w", err)
- }
+ go func() {
+ // Start the command
+ if err := session.Start(command); err != nil {
+ resultChan <- commandResult{err: fmt.Errorf("failed to start command: %w", err)}
+ return
+ }
- // Check for errors
- errOutput, err := io.ReadAll(stderr)
- if err != nil {
- return "", fmt.Errorf("failed to read stderr: %w", err)
- }
+ // Read stdout and stderr concurrently
+ var stdoutData, stderrData []byte
+ var stdoutErr, stderrErr error
+ var wg sync.WaitGroup
+
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+ stdoutData, stdoutErr = io.ReadAll(stdout)
+ }()
+
+ go func() {
+ defer wg.Done()
+ stderrData, stderrErr = io.ReadAll(stderr)
+ }()
+
+ // Wait for all readers to complete
+ wg.Wait()
+
+ // Handle any read errors
+ if stdoutErr != nil {
+ resultChan <- commandResult{err: fmt.Errorf("failed to read stdout: %w", stdoutErr)}
+ return
+ }
+ if stderrErr != nil {
+ resultChan <- commandResult{err: fmt.Errorf("failed to read stderr: %w", stderrErr)}
+ return
+ }
- // Wait for the command to complete
- if err := session.Wait(); err != nil {
- return "", fmt.Errorf("command failed: %s: %w", string(errOutput), err)
- }
+ // Wait for the command to complete
+ err := session.Wait()
+ if err != nil {
+ resultChan <- commandResult{
+ err: fmt.Errorf("command failed: %w", err),
+ stderrData: string(stderrData),
+ }
+ return
+ }
- return string(outputBytes), nil
-}
+ resultChan <- commandResult{
+ output: string(stdoutData),
+ stderrData: string(stderrData),
+ }
+ }()
-// ValidateConnection tests the SSH connection without executing a command
-func (c *Client) ValidateConnection() (err error) {
- session, err := c.client.NewSession()
- if err != nil {
- return fmt.Errorf("failed to create session: %w", err)
- }
+ // Wait for completion or cancellation
+ select {
+ case result := <-resultChan:
+ if result.err != nil {
+ if result.stderrData != "" {
+ return "", fmt.Errorf("%w: %s", result.err, result.stderrData)
+ }
+ return "", result.err
+ }
+ return result.output, nil
+
+ case <-ctx.Done():
+ if err := session.Signal(ssh.SIGTERM); err != nil && !isClosedError(err) {
+ fmt.Printf("error sending SIGTERM: %v\n", err)
+ }
+ return "", ctx.Err()
- defer func() {
- closeErr := session.Close()
- if err == nil {
- // If validation succeeded, return any close error
- err = closeErr
- } else if closeErr != nil {
- // If both validation and close failed, combine the errors
- err = fmt.Errorf("validation error: %w; close error: %v", err, closeErr)
+ case <-time.After(c.config.Timeout):
+ if err := session.Signal(ssh.SIGTERM); err != nil && !isClosedError(err) {
+ fmt.Printf("error sending SIGTERM: %v\n", err)
}
- }()
+ return "", context.DeadlineExceeded
+ }
+}
- return nil
+func isClosedError(err error) bool {
+ if err == nil {
+ return false
+ }
+ return strings.Contains(err.Error(), "use of closed network connection") ||
+ strings.Contains(err.Error(), "connection reset by peer") ||
+ strings.Contains(err.Error(), "closed network connection") ||
+ strings.Contains(err.Error(), "EOF")
}
@@ -373,23 +473,27 @@
import (
"context"
-
- "github.com/rebelopsio/duet/internal/config/executor"
+ "fmt"
)
+// ExecutorInterface defines the required methods for command execution
+type ExecutorInterface interface {
+ Execute(ctx context.Context, command string) (string, error)
+}
+
type PackageManager struct {
- executor *executor.Executor
+ executor ExecutorInterface
}
-func NewPackageManager(executor *executor.Executor) *PackageManager {
+func NewPackageManager(executor ExecutorInterface) *PackageManager {
return &PackageManager{
executor: executor,
}
}
-func (pm *PackageManager) Install(ctx context.Context, packageName string) error {
- // Implementation
- return nil
+func (pm *PackageManager) Install(ctx context.Context, packageName string) error {
+ _, err := pm.executor.Execute(ctx, fmt.Sprintf("apt-get install -y %s", packageName))
+ return err
}
@@ -405,14 +509,14 @@
state *lua.LState
}
-func NewEngine() *Engine {
+func NewEngine() *Engine {
return &Engine{
state: lua.NewState(),
}
}
-func (e *Engine) Close() {
- if e.state != nil {
+func (e *Engine) Close() {
+ if e.state != nil {
e.state.Close()
}
}
@@ -421,13 +525,13 @@
return e.state.DoFile(filename)
}
-func (e *Engine) CallFunction(name string, args ...lua.LValue) (lua.LValue, error) {
+func (e *Engine) CallFunction(name string, args ...lua.LValue) (lua.LValue, error) {
fn := e.state.GetGlobal(name)
- if fn == lua.LNil {
+ if fn == lua.LNil {
return nil, fmt.Errorf("function %s not found", name)
}
- err := e.state.CallByParam(lua.P{
+ err := e.state.CallByParam(lua.P{
Fn: fn,
NRet: 1,
Protect: true,
@@ -436,7 +540,7 @@
return nil, fmt.Errorf("error calling function %s: %w", name, err)
}
- ret := e.state.Get(-1)
+ ret := e.state.Get(-1)
e.state.Pop(1)
return ret, nil
}
@@ -468,47 +572,47 @@
ConfigApplied bool
}
-func NewStore(dbPath string) (*Store, error) {
+func NewStore(dbPath string) (*Store, error) {
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
- if err := db.AutoMigrate(&Resource{}); err != nil {
+ if err := db.AutoMigrate(&Resource{}); err != nil {
return nil, fmt.Errorf("failed to migrate database: %w", err)
}
- return &Store{db: db}, nil
+ return &Store{db: db}, nil
}
// GetResources retrieves all resources from the store
-func (s *Store) GetResources(ctx context.Context) ([]Resource, error) {
+func (s *Store) GetResources(ctx context.Context) ([]Resource, error) {
var resources []Resource
result := s.db.WithContext(ctx).Find(&resources)
if result.Error != nil {
return nil, fmt.Errorf("failed to get resources: %w", result.Error)
}
- return resources, nil
+ return resources, nil
}
// SaveResource saves a resource to the store
-func (s *Store) SaveResource(ctx context.Context, resource *Resource) error {
+func (s *Store) SaveResource(ctx context.Context, resource *Resource) error {
result := s.db.WithContext(ctx).Save(resource)
return result.Error
}
// GetResource retrieves a single resource by ID
-func (s *Store) GetResource(ctx context.Context, id string) (*Resource, error) {
+func (s *Store) GetResource(ctx context.Context, id string) (*Resource, error) {
var resource Resource
result := s.db.WithContext(ctx).First(&resource, "id = ?", id)
- if result.Error != nil {
+ if result.Error != nil {
return nil, result.Error
}
- return &resource, nil
+ return &resource, nil
}
// DeleteResource removes a resource from the store
-func (s *Store) DeleteResource(ctx context.Context, id string) error {
+func (s *Store) DeleteResource(ctx context.Context, id string) error {
result := s.db.WithContext(ctx).Delete(&Resource{}, "id = ?", id)
return result.Error
}
@@ -519,12 +623,12 @@
import (
"context"
- "github.com/rebelopsio/duet/internal/core/state"
"github.com/rebelopsio/duet/internal/iac/provider"
+ "github.com/rebelopsio/duet/pkg/types"
)
type Change struct {
- Resource provider.Resource
+ Resource types.Resource
Config map[string]interface{}
Type string
Provider string
@@ -535,23 +639,21 @@
}
type Planner struct {
- store *state.Store
providers map[string]provider.Provider
}
-func NewPlanner(store *state.Store) *Planner {
+func NewPlanner() *Planner {
return &Planner{
- store: store,
providers: make(map[string]provider.Provider),
}
}
-func (p *Planner) RegisterProvider(provider provider.Provider) {
+func (p *Planner) RegisterProvider(provider provider.Provider) {
p.providers[provider.Name()] = provider
}
-func (p *Planner) CreatePlan(ctx context.Context, config map[string]interface{}) (*Plan, error) {
- // Implementation
+func (p *Planner) CreatePlan(ctx context.Context, config map[string]interface{}) (*Plan, error) {
+ // Implementation will go here
return &Plan{}, nil
}
diff --git a/internal/config/executor/executor.go b/internal/config/executor/executor.go
index 75449db..0f4e04d 100644
--- a/internal/config/executor/executor.go
+++ b/internal/config/executor/executor.go
@@ -2,26 +2,23 @@ package executor
import (
"context"
- "fmt"
-
- "github.com/rebelopsio/duet/internal/config/ssh"
)
-type Executor struct {
- sshClient *ssh.Client
+// Executor defines the interface for executing commands
+type ExecutorInterface interface {
+ Execute(ctx context.Context, command string) (string, error)
}
-func NewExecutor(sshConfig *ssh.Config) (*Executor, error) {
- client, err := ssh.NewClient(sshConfig)
- if err != nil {
- return nil, fmt.Errorf("failed to create SSH client: %w", err)
- }
+type Executor struct {
+ executor ExecutorInterface
+}
+func NewExecutor(executor ExecutorInterface) *Executor {
return &Executor{
- sshClient: client,
- }, nil
+ executor: executor,
+ }
}
func (e *Executor) Execute(ctx context.Context, command string) (string, error) {
- return e.sshClient.Execute(ctx, command)
+ return e.executor.Execute(ctx, command)
}
diff --git a/internal/config/executor/executor_test.go b/internal/config/executor/executor_test.go
new file mode 100644
index 0000000..d649899
--- /dev/null
+++ b/internal/config/executor/executor_test.go
@@ -0,0 +1,38 @@
+package executor
+
+import (
+ "context"
+ "testing"
+)
+
+// mockExecutor implements the necessary methods for testing
+type mockExecutor struct {
+ executeFunc func(ctx context.Context, command string) (string, error)
+}
+
+func (m *mockExecutor) Execute(ctx context.Context, command string) (string, error) {
+ if m.executeFunc != nil {
+ return m.executeFunc(ctx, command)
+ }
+ return "", nil
+}
+
+func TestExecutor(t *testing.T) {
+ t.Run("ExecuteCommand", func(t *testing.T) {
+ expectedOutput := "command output"
+ executor := &mockExecutor{
+ executeFunc: func(ctx context.Context, command string) (string, error) {
+ return expectedOutput, nil
+ },
+ }
+
+ output, err := executor.Execute(context.Background(), "test command")
+ if err != nil {
+ t.Fatalf("Expected no error, got %v", err)
+ }
+
+ if output != expectedOutput {
+ t.Errorf("Expected %q, got %q", expectedOutput, output)
+ }
+ })
+}
diff --git a/internal/config/ssh/client.go b/internal/config/ssh/client.go
index 869360c..9e9668b 100644
--- a/internal/config/ssh/client.go
+++ b/internal/config/ssh/client.go
@@ -4,48 +4,95 @@ import (
"context"
"fmt"
"io"
+ "net"
+ "strings"
+ "sync"
+ "time"
"golang.org/x/crypto/ssh"
)
+// Config holds the SSH client configuration
type Config struct {
Host string
User string
PrivateKey string
Port int
+ Timeout time.Duration
}
+// Client represents an SSH client
type Client struct {
config *Config
client *ssh.Client
}
+// NewClient creates a new SSH client with timeouts
func NewClient(config *Config) (*Client, error) {
+ if config.Timeout == 0 {
+ config.Timeout = 30 * time.Second
+ }
+
+ // Parse the private key
signer, err := ssh.ParsePrivateKey([]byte(config.PrivateKey))
if err != nil {
return nil, fmt.Errorf("failed to parse private key: %w", err)
}
+ // Create SSH client config
sshConfig := &ssh.ClientConfig{
User: config.User,
Auth: []ssh.AuthMethod{
ssh.PublicKeys(signer),
},
- HostKeyCallback: ssh.InsecureIgnoreHostKey(), // Note: In production, use proper host key verification
+ HostKeyCallback: ssh.InsecureIgnoreHostKey(),
+ Timeout: config.Timeout,
}
+ // Create a connection with timeout
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
- client, err := ssh.Dial("tcp", addr, sshConfig)
+ conn, err := net.DialTimeout("tcp", addr, config.Timeout)
if err != nil {
- return nil, fmt.Errorf("failed to dial SSH: %w", err)
+ return nil, fmt.Errorf("failed to connect: %w", err)
}
+ // Set connection deadline
+ if err := conn.SetDeadline(time.Now().Add(config.Timeout)); err != nil {
+ closeErr := conn.Close()
+ if closeErr != nil {
+ return nil, fmt.Errorf("failed to set connection deadline and close connection: %v, close error: %w", err, closeErr)
+ }
+ return nil, fmt.Errorf("failed to set connection deadline: %w", err)
+ }
+
+ // Create new SSH client connection
+ c, chans, reqs, err := ssh.NewClientConn(conn.(*net.TCPConn), addr, sshConfig)
+ if err != nil {
+ closeErr := conn.Close()
+ if closeErr != nil {
+ return nil, fmt.Errorf("failed to create SSH connection and close connection: %v, close error: %w", err, closeErr)
+ }
+ return nil, fmt.Errorf("failed to create SSH connection: %w", err)
+ }
+
+ // Clear the deadline after successful handshake
+ if err := conn.SetDeadline(time.Time{}); err != nil {
+ closeErr := c.Close()
+ if closeErr != nil {
+ return nil, fmt.Errorf("failed to clear connection deadline and close client: %v, close error: %w", err, closeErr)
+ }
+ return nil, fmt.Errorf("failed to clear connection deadline: %w", err)
+ }
+
+ client := ssh.NewClient(c, chans, reqs)
+
return &Client{
config: config,
client: client,
}, nil
}
+// Close closes the SSH connection
func (c *Client) Close() error {
if c.client != nil {
return c.client.Close()
@@ -53,76 +100,132 @@ func (c *Client) Close() error {
return nil
}
-func (c *Client) Execute(ctx context.Context, command string) (output string, err error) {
+// ValidateConnection tests if the SSH connection is working
+func (c *Client) ValidateConnection() error {
session, err := c.client.NewSession()
if err != nil {
- return "", fmt.Errorf("failed to create session: %w", err)
+ return fmt.Errorf("failed to create session: %w", err)
}
+ defer func() {
+ if err := session.Close(); err != nil {
+ fmt.Printf("error closing session: %v\n", err)
+ }
+ }()
+ return nil
+}
+// Execute runs a command over SSH with context for cancellation
+func (c *Client) Execute(ctx context.Context, command string) (string, error) {
+ session, err := c.client.NewSession()
+ if err != nil {
+ return "", fmt.Errorf("failed to create session: %w", err)
+ }
defer func() {
- closeErr := session.Close()
- if err == nil {
- // If there was no error from the command, return any close error
- err = closeErr
- } else if closeErr != nil {
- // If there were both command and close errors, combine them
- err = fmt.Errorf("command error: %w; close error: %v", err, closeErr)
+ if err := session.Close(); err != nil && !isClosedError(err) {
+ fmt.Printf("error closing session: %v\n", err)
}
}()
- // Set up pipes for stdout and stderr
- var stdout, stderr io.Reader
- stdout, err = session.StdoutPipe()
+ // Set up pipes for output
+ stdout, err := session.StdoutPipe()
if err != nil {
return "", fmt.Errorf("failed to create stdout pipe: %w", err)
}
- stderr, err = session.StderrPipe()
+ stderr, err := session.StderrPipe()
if err != nil {
return "", fmt.Errorf("failed to create stderr pipe: %w", err)
}
- // Start the command
- if err := session.Start(command); err != nil {
- return "", fmt.Errorf("failed to start command: %w", err)
+ type commandResult struct {
+ output string
+ err error
+ stderrData string
}
+ resultChan := make(chan commandResult, 1)
- // Read output
- outputBytes, err := io.ReadAll(stdout)
- if err != nil {
- return "", fmt.Errorf("failed to read stdout: %w", err)
- }
+ go func() {
+ // Start the command
+ if err := session.Start(command); err != nil {
+ resultChan <- commandResult{err: fmt.Errorf("failed to start command: %w", err)}
+ return
+ }
- // Check for errors
- errOutput, err := io.ReadAll(stderr)
- if err != nil {
- return "", fmt.Errorf("failed to read stderr: %w", err)
- }
+ // Read stdout and stderr concurrently
+ var stdoutData, stderrData []byte
+ var stdoutErr, stderrErr error
+ var wg sync.WaitGroup
+
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+ stdoutData, stdoutErr = io.ReadAll(stdout)
+ }()
+
+ go func() {
+ defer wg.Done()
+ stderrData, stderrErr = io.ReadAll(stderr)
+ }()
+
+ // Wait for all readers to complete
+ wg.Wait()
+
+ // Handle any read errors
+ if stdoutErr != nil {
+ resultChan <- commandResult{err: fmt.Errorf("failed to read stdout: %w", stdoutErr)}
+ return
+ }
+ if stderrErr != nil {
+ resultChan <- commandResult{err: fmt.Errorf("failed to read stderr: %w", stderrErr)}
+ return
+ }
- // Wait for the command to complete
- if err := session.Wait(); err != nil {
- return "", fmt.Errorf("command failed: %s: %w", string(errOutput), err)
- }
+ // Wait for the command to complete
+ err := session.Wait()
+ if err != nil {
+ resultChan <- commandResult{
+ err: fmt.Errorf("command failed: %w", err),
+ stderrData: string(stderrData),
+ }
+ return
+ }
- return string(outputBytes), nil
-}
+ resultChan <- commandResult{
+ output: string(stdoutData),
+ stderrData: string(stderrData),
+ }
+ }()
-// ValidateConnection tests the SSH connection without executing a command
-func (c *Client) ValidateConnection() (err error) {
- session, err := c.client.NewSession()
- if err != nil {
- return fmt.Errorf("failed to create session: %w", err)
- }
+ // Wait for completion or cancellation
+ select {
+ case result := <-resultChan:
+ if result.err != nil {
+ if result.stderrData != "" {
+ return "", fmt.Errorf("%w: %s", result.err, result.stderrData)
+ }
+ return "", result.err
+ }
+ return result.output, nil
- defer func() {
- closeErr := session.Close()
- if err == nil {
- // If validation succeeded, return any close error
- err = closeErr
- } else if closeErr != nil {
- // If both validation and close failed, combine the errors
- err = fmt.Errorf("validation error: %w; close error: %v", err, closeErr)
+ case <-ctx.Done():
+ if err := session.Signal(ssh.SIGTERM); err != nil && !isClosedError(err) {
+ fmt.Printf("error sending SIGTERM: %v\n", err)
}
- }()
+ return "", ctx.Err()
- return nil
+ case <-time.After(c.config.Timeout):
+ if err := session.Signal(ssh.SIGTERM); err != nil && !isClosedError(err) {
+ fmt.Printf("error sending SIGTERM: %v\n", err)
+ }
+ return "", context.DeadlineExceeded
+ }
+}
+
+func isClosedError(err error) bool {
+ if err == nil {
+ return false
+ }
+ return strings.Contains(err.Error(), "use of closed network connection") ||
+ strings.Contains(err.Error(), "connection reset by peer") ||
+ strings.Contains(err.Error(), "closed network connection") ||
+ strings.Contains(err.Error(), "EOF")
}
diff --git a/internal/config/ssh/client_test.go b/internal/config/ssh/client_test.go
index b884923..9d16933 100644
--- a/internal/config/ssh/client_test.go
+++ b/internal/config/ssh/client_test.go
@@ -1,165 +1,492 @@
package ssh
import (
+ "bytes"
"context"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/x509"
+ "encoding/pem"
"fmt"
"io"
"net"
+ "strings"
+ "sync"
"testing"
+ "time"
"golang.org/x/crypto/ssh"
)
-// mockSSHServer simulates an SSH server for testing
+type keyPair struct {
+ PublicKey ssh.PublicKey
+ PrivateKey string
+}
+
type mockSSHServer struct {
- listener net.Listener
- config *ssh.ServerConfig
+ listener net.Listener
+ ctx context.Context
+ config *ssh.ServerConfig
+ ready chan struct{}
+ done chan struct{}
+ activeConns map[string]net.Conn
+ t *testing.T
+ cancel context.CancelFunc
+ wg sync.WaitGroup
+ activesMutex sync.RWMutex
+}
+
+// generateTestKey generates a test RSA key pair for testing
+func generateTestKey() (*keyPair, error) {
+ privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ return nil, fmt.Errorf("failed to generate private key: %w", err)
+ }
+
+ // Convert private key to PEM format
+ privateKeyPEM := &pem.Block{
+ Type: "RSA PRIVATE KEY",
+ Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
+ }
+
+ // Generate SSH public key
+ publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create public key: %w", err)
+ }
+
+ return &keyPair{
+ PrivateKey: string(pem.EncodeToMemory(privateKeyPEM)),
+ PublicKey: publicKey,
+ }, nil
}
-func newMockSSHServer(t *testing.T) (*mockSSHServer, error) {
+func newMockSSHServer(t *testing.T, keys *keyPair) (*mockSSHServer, error) {
+ ctx, cancel := context.WithCancel(context.Background())
+
config := &ssh.ServerConfig{
- PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
- return nil, nil
+ PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
+ if key.Type() == keys.PublicKey.Type() && bytes.Equal(key.Marshal(), keys.PublicKey.Marshal()) {
+ return &ssh.Permissions{}, nil
+ }
+ return nil, fmt.Errorf("unknown public key")
},
}
- privateKey, err := ssh.ParsePrivateKey([]byte(testPrivateKey))
+ signer, err := ssh.ParsePrivateKey([]byte(keys.PrivateKey))
if err != nil {
+ cancel()
return nil, fmt.Errorf("failed to parse private key: %w", err)
}
- config.AddHostKey(privateKey)
+ config.AddHostKey(signer)
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
+ cancel()
return nil, fmt.Errorf("failed to listen: %w", err)
}
server := &mockSSHServer{
- listener: listener,
- config: config,
+ listener: listener,
+ config: config,
+ ready: make(chan struct{}),
+ done: make(chan struct{}),
+ activeConns: make(map[string]net.Conn),
+ t: t,
+ ctx: ctx,
+ cancel: cancel,
}
- go server.serve(t)
- return server, nil
+ server.wg.Add(1)
+ go server.acceptConnections()
+
+ // Wait for server to be ready
+ select {
+ case <-server.ready:
+ return server, nil
+ case <-time.After(2 * time.Second):
+ cancel()
+ if err := listener.Close(); err != nil {
+ t.Logf("Failed to close listener: %v", err)
+ }
+ return nil, fmt.Errorf("timeout waiting for server to be ready")
+ }
}
-func (s *mockSSHServer) serve(t *testing.T) {
+func (s *mockSSHServer) acceptConnections() {
+ defer s.wg.Done()
+ defer close(s.done)
+ defer s.cancel()
+
+ // Signal that we're ready to accept connections
+ close(s.ready)
+
for {
+ if err := s.listener.(*net.TCPListener).SetDeadline(time.Now().Add(time.Second)); err != nil {
+ s.t.Logf("Failed to set accept deadline: %v", err)
+ return
+ }
+
conn, err := s.listener.Accept()
if err != nil {
+ if isTimeout(err) {
+ select {
+ case <-s.ctx.Done():
+ return
+ default:
+ continue
+ }
+ }
if !isClosedError(err) {
- t.Errorf("Failed to accept connection: %v", err)
+ s.t.Logf("Accept error: %v", err)
}
return
}
- _, chans, reqs, err := ssh.NewServerConn(conn, s.config)
- if err != nil {
- t.Errorf("Failed to handshake: %v", err)
- continue
- }
+ s.activesMutex.Lock()
+ s.activeConns[conn.RemoteAddr().String()] = conn
+ s.activesMutex.Unlock()
- go ssh.DiscardRequests(reqs)
- go handleChannels(chans, t)
+ s.wg.Add(1)
+ go s.handleConnection(conn)
}
}
-func handleChannels(chans <-chan ssh.NewChannel, t *testing.T) {
- for newChannel := range chans {
- if newChannel.ChannelType() != "session" {
- newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
- continue
+func isTimeout(err error) bool {
+ if err == nil {
+ return false
+ }
+ return strings.Contains(err.Error(), "timeout") ||
+ strings.Contains(err.Error(), "deadline exceeded") ||
+ err == context.DeadlineExceeded
+}
+
+func (s *mockSSHServer) handleConnection(conn net.Conn) {
+ defer s.wg.Done()
+ defer func() {
+ s.activesMutex.Lock()
+ delete(s.activeConns, conn.RemoteAddr().String())
+ s.activesMutex.Unlock()
+ if err := conn.Close(); err != nil && !isClosedError(err) {
+ s.t.Logf("Connection close error: %v", err)
}
+ }()
- channel, requests, err := newChannel.Accept()
- if err != nil {
- t.Errorf("Failed to accept channel: %v", err)
- continue
+ sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.config)
+ if err != nil {
+ if !isClosedError(err) {
+ s.t.Logf("SSH handshake error: %v", err)
}
+ return
+ }
+
+ go func() {
+ <-s.ctx.Done()
+ if err := sshConn.Close(); err != nil && !isClosedError(err) {
+ s.t.Logf("SSH connection close error: %v", err)
+ }
+ }()
+
+ go ssh.DiscardRequests(reqs)
+ s.handleChannels(chans)
+}
+
+func (s *mockSSHServer) handleRequests(channel ssh.Channel, requests <-chan *ssh.Request) {
+ defer func() {
+ if err := channel.Close(); err != nil && !isClosedError(err) {
+ s.t.Logf("Channel close error: %v", err)
+ }
+ }()
- go func(in <-chan *ssh.Request) {
- for req := range in {
- switch req.Type {
- case "exec":
- payload := struct{ Command string }{}
- ssh.Unmarshal(req.Payload, &payload)
+ for {
+ select {
+ case <-s.ctx.Done():
+ return
+ case req, ok := <-requests:
+ if !ok {
+ return
+ }
+ switch req.Type {
+ case "exec":
+ exitStatus := make([]byte, 4)
- if payload.Command == "echo test" {
- io.WriteString(channel, "test\n")
+ payload := struct{ Command string }{}
+ if err := ssh.Unmarshal(req.Payload, &payload); err != nil {
+ s.t.Logf("Unmarshal error: %v", err)
+ if err := req.Reply(false, nil); err != nil {
+ s.t.Logf("Reply error: %v", err)
}
+ continue
+ }
- channel.Close()
+ if err := req.Reply(true, nil); err != nil {
+ s.t.Logf("Reply error: %v", err)
+ continue
}
- req.Reply(true, nil)
+
+ if payload.Command == "echo test" {
+ if _, err := io.WriteString(channel, "test\n"); err != nil {
+ s.t.Logf("Write error: %v", err)
+ }
+
+ // Send exit status
+ _, err := channel.SendRequest("exit-status", false, exitStatus)
+ if err != nil {
+ s.t.Logf("Failed to send exit status: %v", err)
+ }
+
+ if err := channel.CloseWrite(); err != nil && !isClosedError(err) {
+ s.t.Logf("CloseWrite error: %v", err)
+ }
+ return
+ }
+
+ if payload.Command == "sleep 10" {
+ // For the timeout test, we'll block until context is cancelled
+ select {
+ case <-s.ctx.Done():
+ // Send non-zero exit status for cancelled command
+ exitStatus[3] = 1
+ _, err := channel.SendRequest("exit-status", false, exitStatus)
+ if err != nil {
+ s.t.Logf("Failed to send exit status: %v", err)
+ }
+ case <-time.After(10 * time.Second):
+ // Normal completion
+ _, err := channel.SendRequest("exit-status", false, exitStatus)
+ if err != nil {
+ s.t.Logf("Failed to send exit status: %v", err)
+ }
+ }
+
+ if err := channel.CloseWrite(); err != nil && !isClosedError(err) {
+ s.t.Logf("CloseWrite error: %v", err)
+ }
+ return
+ }
+
+ // Unknown command, send error exit status
+ exitStatus[3] = 1
+ _, err := channel.SendRequest("exit-status", false, exitStatus)
+ if err != nil {
+ s.t.Logf("Failed to send exit status: %v", err)
+ }
+
+ if err := channel.CloseWrite(); err != nil && !isClosedError(err) {
+ s.t.Logf("CloseWrite error: %v", err)
+ }
+ return
+
+ default:
+ if err := req.Reply(false, nil); err != nil {
+ s.t.Logf("Reply error: %v", err)
+ }
+ }
+ }
+ }
+}
+
+func (s *mockSSHServer) handleChannels(chans <-chan ssh.NewChannel) {
+ for {
+ select {
+ case <-s.ctx.Done():
+ return
+ case newChannel, ok := <-chans:
+ if !ok {
+ return
}
- }(requests)
+ go s.handleChannel(newChannel)
+ }
}
}
-func isClosedError(err error) bool {
- return err.Error() == "use of closed network connection"
+func (s *mockSSHServer) handleChannel(newChannel ssh.NewChannel) {
+ if newChannel.ChannelType() != "session" {
+ if err := newChannel.Reject(ssh.UnknownChannelType, "unknown channel type"); err != nil {
+ s.t.Logf("Channel reject error: %v", err)
+ }
+ return
+ }
+
+ channel, requests, err := newChannel.Accept()
+ if err != nil {
+ s.t.Logf("Channel accept error: %v", err)
+ return
+ }
+
+ go s.handleRequests(channel, requests)
}
-const testPrivateKey = `-----BEGIN OPENSSH PRIVATE KEY-----
-b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABFwAAAAdzc2gtcn
-NhAAAAAwEAAQAAAQEAxU4rixQXoahCL2gVoNWswNMFxYEiO0YH9YbB1qh+9nYRYGzEOc0l
-...
------END OPENSSH PRIVATE KEY-----`
+func (s *mockSSHServer) shutdown() error {
+ s.cancel()
+
+ if err := s.listener.Close(); err != nil {
+ return fmt.Errorf("failed to close listener: %w", err)
+ }
+
+ done := make(chan struct{})
+ go func() {
+ s.wg.Wait()
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ return nil
+ case <-time.After(5 * time.Second):
+ return fmt.Errorf("timeout waiting for server shutdown")
+ }
+}
func TestClient(t *testing.T) {
- mockServer, err := newMockSSHServer(t)
+ if testing.Short() {
+ t.Skip("Skipping SSH tests in short mode")
+ }
+
+ // Generate a test key for this test run
+ keys, err := generateTestKey()
+ if err != nil {
+ t.Fatalf("Failed to generate test key: %v", err)
+ }
+
+ mockServer, err := newMockSSHServer(t, keys)
if err != nil {
t.Fatalf("Failed to start mock SSH server: %v", err)
}
- defer mockServer.listener.Close()
+
+ // Use a cleanup function to ensure proper shutdown
+ cleanup := func() {
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+
+ done := make(chan error, 1)
+ go func() {
+ done <- mockServer.shutdown()
+ }()
+
+ select {
+ case err := <-done:
+ if err != nil {
+ t.Logf("Server shutdown error: %v", err)
+ }
+ case <-ctx.Done():
+ t.Log("Server shutdown timed out")
+ }
+ }
+ defer cleanup()
serverAddr := mockServer.listener.Addr().String()
- host, port, err := net.SplitHostPort(serverAddr)
+ host, portStr, err := net.SplitHostPort(serverAddr)
if err != nil {
t.Fatalf("Failed to parse server address: %v", err)
}
+ var port int
+ if _, err := fmt.Sscanf(portStr, "%d", &port); err != nil {
+ t.Fatalf("Failed to parse port number: %v", err)
+ }
+
config := &Config{
Host: host,
+ Port: port,
User: "test",
- PrivateKey: testPrivateKey,
- Port: parseInt(port),
+ PrivateKey: keys.PrivateKey,
+ Timeout: 2 * time.Second,
}
- t.Run("Connect", func(t *testing.T) {
+ // Helper function to create and cleanup client
+ createClient := func(t *testing.T) (*Client, func()) {
client, err := NewClient(config)
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
- defer client.Close()
- err = client.ValidateConnection()
- if err != nil {
- t.Errorf("Failed to validate connection: %v", err)
+ cleanup := func() {
+ if err := client.Close(); err != nil && !isClosedError(err) {
+ t.Logf("Client close error: %v", err)
+ }
+ }
+
+ return client, cleanup
+ }
+
+ t.Run("Connect", func(t *testing.T) {
+ client, cleanup := createClient(t)
+ defer cleanup()
+
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+
+ done := make(chan error, 1)
+ go func() {
+ done <- client.ValidateConnection()
+ }()
+
+ select {
+ case err := <-done:
+ if err != nil {
+ t.Errorf("Failed to validate connection: %v", err)
+ }
+ case <-ctx.Done():
+ t.Error("Connection validation timed out")
}
})
t.Run("Execute", func(t *testing.T) {
- client, err := NewClient(config)
- if err != nil {
- t.Fatalf("Failed to create client: %v", err)
- }
- defer client.Close()
+ client, cleanup := createClient(t)
+ defer cleanup()
- output, err := client.Execute(context.Background(), "echo test")
- if err != nil {
- t.Fatalf("Failed to execute command: %v", err)
- }
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+
+ done := make(chan struct {
+ output string
+ err error
+ }, 1)
- expected := "test\n"
- if output != expected {
- t.Errorf("Expected output %q, got %q", expected, output)
+ go func() {
+ output, err := client.Execute(ctx, "echo test")
+ done <- struct {
+ output string
+ err error
+ }{output, err}
+ }()
+
+ select {
+ case result := <-done:
+ if result.err != nil {
+ t.Fatalf("Failed to execute command: %v", result.err)
+ }
+ if result.output != "test\n" {
+ t.Errorf("Expected output %q, got %q", "test\n", result.output)
+ }
+ case <-ctx.Done():
+ t.Fatal("Command execution timed out")
}
})
-}
-func parseInt(s string) int {
- var port int
- fmt.Sscanf(s, "%d", &port)
- return port
+ t.Run("ExecuteWithTimeout", func(t *testing.T) {
+ client, cleanup := createClient(t)
+ defer cleanup()
+
+ ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+ defer cancel()
+
+ done := make(chan error, 1)
+ go func() {
+ _, err := client.Execute(ctx, "sleep 10")
+ done <- err
+ }()
+
+ select {
+ case err := <-done:
+ if err == nil {
+ t.Error("Expected error, got nil")
+ } else if err != context.DeadlineExceeded {
+ t.Errorf("Expected context.DeadlineExceeded error, got: %v", err)
+ }
+ case <-time.After(2 * time.Second):
+ t.Fatal("Test timed out waiting for command timeout")
+ }
+ })
}
diff --git a/internal/config/tasks/package.go b/internal/config/tasks/package.go
index 4a5ee04..bec058a 100644
--- a/internal/config/tasks/package.go
+++ b/internal/config/tasks/package.go
@@ -2,21 +2,25 @@ package tasks
import (
"context"
-
- "github.com/rebelopsio/duet/internal/config/executor"
+ "fmt"
)
+// ExecutorInterface defines the required methods for command execution
+type ExecutorInterface interface {
+ Execute(ctx context.Context, command string) (string, error)
+}
+
type PackageManager struct {
- executor *executor.Executor
+ executor ExecutorInterface
}
-func NewPackageManager(executor *executor.Executor) *PackageManager {
+func NewPackageManager(executor ExecutorInterface) *PackageManager {
return &PackageManager{
executor: executor,
}
}
func (pm *PackageManager) Install(ctx context.Context, packageName string) error {
- // Implementation
- return nil
+ _, err := pm.executor.Execute(ctx, fmt.Sprintf("apt-get install -y %s", packageName))
+ return err
}
diff --git a/internal/config/tasks/package_test.go b/internal/config/tasks/package_test.go
new file mode 100644
index 0000000..0566c60
--- /dev/null
+++ b/internal/config/tasks/package_test.go
@@ -0,0 +1,33 @@
+package tasks
+
+import (
+ "context"
+ "testing"
+)
+
+type mockExecutor struct {
+ executeFunc func(ctx context.Context, command string) (string, error)
+}
+
+func (m *mockExecutor) Execute(ctx context.Context, command string) (string, error) {
+ return m.executeFunc(ctx, command)
+}
+
+func TestPackageManager(t *testing.T) {
+ t.Run("InstallPackage", func(t *testing.T) {
+ executorMock := &mockExecutor{
+ executeFunc: func(ctx context.Context, command string) (string, error) {
+ return "Package installed successfully", nil
+ },
+ }
+
+ pm := &PackageManager{
+ executor: executorMock,
+ }
+
+ err := pm.Install(context.Background(), "cowsay")
+ if err != nil {
+ t.Fatalf("Failed to install package: %v", err)
+ }
+ })
+}
diff --git a/internal/core/state/store_test.go b/internal/core/state/store_test.go
index 6692861..abbff2e 100644
--- a/internal/core/state/store_test.go
+++ b/internal/core/state/store_test.go
@@ -2,6 +2,7 @@ package state
import (
"context"
+ "log"
"os"
"path/filepath"
"testing"
@@ -13,7 +14,11 @@ func TestStore(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
- defer os.RemoveAll(tmpDir)
+ defer func() {
+ if err := os.RemoveAll(tmpDir); err != nil {
+ log.Printf("Warning: failed to remove temp directory: %v", err)
+ }
+ }()
dbPath := filepath.Join(tmpDir, "test.db")
diff --git a/internal/iac/planner/planner.go b/internal/iac/planner/planner.go
index 2db2167..de1ef4e 100644
--- a/internal/iac/planner/planner.go
+++ b/internal/iac/planner/planner.go
@@ -3,12 +3,12 @@ package planner
import (
"context"
- "github.com/rebelopsio/duet/internal/core/state"
"github.com/rebelopsio/duet/internal/iac/provider"
+ "github.com/rebelopsio/duet/pkg/types"
)
type Change struct {
- Resource provider.Resource
+ Resource types.Resource
Config map[string]interface{}
Type string
Provider string
@@ -19,13 +19,11 @@ type Plan struct {
}
type Planner struct {
- store *state.Store
providers map[string]provider.Provider
}
-func NewPlanner(store *state.Store) *Planner {
+func NewPlanner() *Planner {
return &Planner{
- store: store,
providers: make(map[string]provider.Provider),
}
}
@@ -35,6 +33,6 @@ func (p *Planner) RegisterProvider(provider provider.Provider) {
}
func (p *Planner) CreatePlan(ctx context.Context, config map[string]interface{}) (*Plan, error) {
- // Implementation
+ // Implementation will go here
return &Plan{}, nil
}
diff --git a/internal/iac/planner/planner_test.go b/internal/iac/planner/planner_test.go
new file mode 100644
index 0000000..fd01850
--- /dev/null
+++ b/internal/iac/planner/planner_test.go
@@ -0,0 +1,93 @@
+package planner
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/rebelopsio/duet/internal/iac/provider"
+ "github.com/rebelopsio/duet/pkg/types"
+)
+
+// mockResource implements the types.Resource interface
+type mockResource struct {
+ created time.Time
+ updated time.Time
+ metadata map[string]interface{}
+ tags map[string]string
+ id string
+ resType types.ResourceType
+ provider string
+ status types.ResourceStatus
+}
+
+func (m *mockResource) GetID() string { return m.id }
+func (m *mockResource) GetType() types.ResourceType { return m.resType }
+func (m *mockResource) GetProvider() string { return m.provider }
+func (m *mockResource) GetStatus() types.ResourceStatus { return m.status }
+func (m *mockResource) GetMetadata() map[string]interface{} { return m.metadata }
+func (m *mockResource) GetTags() map[string]string { return m.tags }
+func (m *mockResource) GetCreatedAt() time.Time { return m.created }
+func (m *mockResource) GetUpdatedAt() time.Time { return m.updated }
+
+// mockProvider implements the provider.Provider interface
+type mockProvider struct {
+ name string
+}
+
+func (m *mockProvider) Name() string { return m.name }
+
+func (m *mockProvider) Create(ctx context.Context, resourceType string, config map[string]interface{}) (provider.Resource, error) {
+ return &mockResource{
+ id: "test-id",
+ resType: types.ResourceType(resourceType),
+ provider: m.name,
+ status: types.StatusRunning,
+ metadata: config,
+ tags: make(map[string]string),
+ created: time.Now(),
+ updated: time.Now(),
+ }, nil
+}
+
+func (m *mockProvider) Read(ctx context.Context, resourceType string, id string) (provider.Resource, error) {
+ return nil, nil
+}
+
+func (m *mockProvider) Update(ctx context.Context, resource provider.Resource, config map[string]interface{}) error {
+ return nil
+}
+
+func (m *mockProvider) Delete(ctx context.Context, resource provider.Resource) error {
+ return nil
+}
+
+func TestPlanner(t *testing.T) {
+ t.Run("CreatePlan", func(t *testing.T) {
+ planner := &Planner{
+ providers: make(map[string]provider.Provider),
+ }
+
+ mockAWSProvider := &mockProvider{name: "aws"}
+ planner.RegisterProvider(mockAWSProvider)
+
+ config := map[string]interface{}{
+ "provider": "aws",
+ "resources": []map[string]interface{}{
+ {
+ "type": "instance",
+ "name": "test-instance",
+ },
+ },
+ }
+
+ plan, err := planner.CreatePlan(context.Background(), config)
+ if err != nil {
+ t.Fatalf("Failed to create plan: %v", err)
+ }
+
+ if plan == nil {
+ t.Error("Expected plan to not be nil")
+ }
+ })
+}
diff --git a/internal/iac/provider/provider.go b/internal/iac/provider/provider.go
index 5f3ead0..afd3a25 100644
--- a/internal/iac/provider/provider.go
+++ b/internal/iac/provider/provider.go
@@ -2,18 +2,27 @@ package provider
import (
"context"
+
+ "github.com/rebelopsio/duet/pkg/types"
)
-type Resource interface {
- ID() string
- Type() string
- Metadata() map[string]interface{}
-}
+// Resource is an alias for types.Resource to maintain package consistency
+type Resource = types.Resource
+// Provider defines the interface for infrastructure providers
type Provider interface {
+ // Name returns the provider's name
Name() string
+
+ // Create creates a new resource
Create(ctx context.Context, resourceType string, config map[string]interface{}) (Resource, error)
+
+ // Read retrieves an existing resource
Read(ctx context.Context, resourceType string, id string) (Resource, error)
+
+ // Update updates an existing resource
Update(ctx context.Context, resource Resource, config map[string]interface{}) error
+
+ // Delete removes an existing resource
Delete(ctx context.Context, resource Resource) error
}