diff --git a/.github/workflows/go-build-and-test.yml b/.github/workflows/go-build-and-test.yml index 2466d45..b1445e0 100644 --- a/.github/workflows/go-build-and-test.yml +++ b/.github/workflows/go-build-and-test.yml @@ -23,7 +23,7 @@ jobs: cd workloadagentplatform # this is the hash of the workloadagentplatform submodule # get the hash by running: go list -m -json github.com/GoogleCloudPlatform/workloadagentplatform@main - git checkout f54e4d15148c638bd3cd1748cf21bfbc88a47593 + git checkout 27b82f043534ca3dd5965f7ebd92b5317e19ecf6 cd .. find workloadagentplatform/sharedprotos -type f -exec sed -i 's|"sharedprotos|"workloadagentplatform/sharedprotos|g' {} + env: diff --git a/.github/workflows/go-build-protos.yml b/.github/workflows/go-build-protos.yml index 8ff650e..7ac803e 100644 --- a/.github/workflows/go-build-protos.yml +++ b/.github/workflows/go-build-protos.yml @@ -31,7 +31,7 @@ jobs: cd workloadagentplatform # this is the hash of the workloadagentplatform submodule # get the hash by running: go list -m -json github.com/GoogleCloudPlatform/workloadagentplatform@main - git checkout f54e4d15148c638bd3cd1748cf21bfbc88a47593 + git checkout 27b82f043534ca3dd5965f7ebd92b5317e19ecf6 cd .. find workloadagentplatform/sharedprotos -type f -exec sed -i 's|"sharedprotos|"workloadagentplatform/sharedprotos|g' {} + env: diff --git a/build.sh b/build.sh index 3adf3b0..ae843aa 100755 --- a/build.sh +++ b/build.sh @@ -56,7 +56,7 @@ if [ "${COMPILE_PROTOS}" == "TRUE" ] && [ ! -d "workloadagentplatform" ]; then cd workloadagentplatform # this is the hash of the workloadagentplatform submodule # get the hash by running: go list -m -json github.com/GoogleCloudPlatform/workloadagentplatform@main - git checkout f54e4d15148c638bd3cd1748cf21bfbc88a47593 + git checkout 27b82f043534ca3dd5965f7ebd92b5317e19ecf6 cd .. # replace the proto imports in the platform that reference the platform find workloadagentplatform/sharedprotos -type f -exec sed -i 's|"sharedprotos|"workloadagentplatform/sharedprotos|g' {} + diff --git a/go.mod b/go.mod index d2a0f46..ea3b092 100644 --- a/go.mod +++ b/go.mod @@ -13,10 +13,10 @@ require ( github.com/DATA-DOG/go-sqlmock v1.5.0 // Get the version by running: // go list -m -json github.com/GoogleCloudPlatform/workloadagentplatform/sharedlibraries@main - github.com/GoogleCloudPlatform/workloadagentplatform/sharedlibraries v0.0.0-20251121062745-f54e4d15148c + github.com/GoogleCloudPlatform/workloadagentplatform/sharedlibraries v0.0.0-20260109223652-27b82f043534 // Get the version by running: // go list -m -json github.com/GoogleCloudPlatform/workloadagentplatform/sharedprotos@main - github.com/GoogleCloudPlatform/workloadagentplatform/sharedprotos v0.0.0-20251121062745-f54e4d15148c + github.com/GoogleCloudPlatform/workloadagentplatform/sharedprotos v0.0.0-20260109223652-27b82f043534 github.com/StackExchange/wmi v1.2.1 github.com/cenkalti/backoff/v4 v4.3.0 github.com/gammazero/workerpool v1.1.3 diff --git a/go.sum b/go.sum index dd522b5..630e5fb 100644 --- a/go.sum +++ b/go.sum @@ -31,10 +31,10 @@ github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20O github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= github.com/GoogleCloudPlatform/agentcommunication_client v0.0.0-20250227185639-b70667e4a927 h1:nn31d5gg+ysSNqWTqSOxsKBj17GJZBqsBx7biZAgYtI= github.com/GoogleCloudPlatform/agentcommunication_client v0.0.0-20250227185639-b70667e4a927/go.mod h1:A1V05o309ZvTwy/FTBooYvvIhzM6mtsJcHJsAkeuAAM= -github.com/GoogleCloudPlatform/workloadagentplatform/sharedlibraries v0.0.0-20251121062745-f54e4d15148c h1:AHOi70kCuubweFKUVWpGgfCTtDC7MoMkAGjRCLWDi5g= -github.com/GoogleCloudPlatform/workloadagentplatform/sharedlibraries v0.0.0-20251121062745-f54e4d15148c/go.mod h1:WrwTr9HFkp+nbHba/tv36eCLkjKAeCx7mqiG/wL3PNU= -github.com/GoogleCloudPlatform/workloadagentplatform/sharedprotos v0.0.0-20251121062745-f54e4d15148c h1:K74cWj8vfONBD1bxWhQ/WLFc+oxDhRvf2RXhXrhOJ9o= -github.com/GoogleCloudPlatform/workloadagentplatform/sharedprotos v0.0.0-20251121062745-f54e4d15148c/go.mod h1:8Ea8vdBuPsWhhwzL9sNK7BFQE9qbkPLZUHxcucWHXaM= +github.com/GoogleCloudPlatform/workloadagentplatform/sharedlibraries v0.0.0-20260109223652-27b82f043534 h1:aNKg05es2YNDllzWmWmB6Yx9BYA9u3fouDrH/aarfZI= +github.com/GoogleCloudPlatform/workloadagentplatform/sharedlibraries v0.0.0-20260109223652-27b82f043534/go.mod h1:WrwTr9HFkp+nbHba/tv36eCLkjKAeCx7mqiG/wL3PNU= +github.com/GoogleCloudPlatform/workloadagentplatform/sharedprotos v0.0.0-20260109223652-27b82f043534 h1:p/dzu0l6Xnz+zK/bTcCrc2WFQFkxC36QMapIs4MMHm4= +github.com/GoogleCloudPlatform/workloadagentplatform/sharedprotos v0.0.0-20260109223652-27b82f043534/go.mod h1:8Ea8vdBuPsWhhwzL9sNK7BFQE9qbkPLZUHxcucWHXaM= github.com/StackExchange/wmi v1.2.1 h1:VIkavFPXSjcnS+O8yTq7NI32k0R5Aj+v39y29VYDOSA= github.com/StackExchange/wmi v1.2.1/go.mod h1:rcmrprowKIVzvc+NUiLncP2uuArMWLCbu9SBzvHz7e8= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY= diff --git a/internal/daemon/oracle/oracle.go b/internal/daemon/oracle/oracle.go index 6613dd1..047219a 100644 --- a/internal/daemon/oracle/oracle.go +++ b/internal/daemon/oracle/oracle.go @@ -19,6 +19,7 @@ package oracle import ( "context" + "fmt" "runtime" "sync" "time" @@ -34,14 +35,27 @@ import ( "github.com/GoogleCloudPlatform/workloadagentplatform/sharedlibraries/recovery" cpb "github.com/GoogleCloudPlatform/workloadagent/protos/configuration" + gapb "github.com/GoogleCloudPlatform/workloadagentplatform/sharedprotos/guestactions" ) const ( // This is an ephemeral channel, meaning ACLs/quota are managed within the instance's project // where the agent is running, unlike registered channels that use a producer project. - defaultChannel = "oracle-operations-ephemeral-channel" + defaultChannel = "oracle-operations-ephemeral-channel" + defaultLockTimeout = 24 * time.Hour ) +// guestActionsManager is an interface satisfied by guestactions.GuestActions +// to allow for mocking in tests. +type guestActionsManager interface { + Start(context.Context, any) +} + +// newGuestActionsManager allows overriding the GuestActions implementation for testing. +var newGuestActionsManager = func() guestActionsManager { + return &guestactions.GuestActions{} +} + func convertCloudProperties(cp *cpb.CloudProperties) *metadataserver.CloudProperties { if cp == nil { return nil @@ -137,24 +151,10 @@ func (s *Service) Start(ctx context.Context, a any) { s.metricCollectionRoutine.StartRoutine(mcCtx) } - oracleHandler := oraclehandlers.New() - handlers := map[string]guestactions.GuestActionHandler{ - "oracle_run_discovery": oracleHandler.RunDiscovery, - "oracle_stop_database": oracleHandler.StopDatabase, - "oracle_disable_autostart": oracleHandler.DisableAutostart, - "oracle_start_database": oracleHandler.StartDatabase, - "oracle_run_datapatch": oracleHandler.RunDatapatch, - "oracle_disable_restricted_mode": oracleHandler.DisableRestrictedMode, - "oracle_start_listener": oracleHandler.StartListener, - "oracle_enable_autostart": oracleHandler.EnableAutostart, - "oracle_health_check": oracleHandler.HealthCheck, - "oracle_data_guard_switchover": oracleHandler.DataGuardSwitchover, - } - gaCtx := log.SetCtx(ctx, "context", "OracleGuestActions") guestActionsRoutine := &recovery.RecoverableRoutine{ Routine: runGuestActions, - RoutineArg: runGuestActionsArgs{s: s, handlers: handlers}, + RoutineArg: runGuestActionsArgs{s: s, handlers: guestActionHandlers()}, ErrorCode: usagemetrics.GuestActionsFailure, UsageLogger: *usagemetrics.UsageLogger, ExpectedMinDuration: 10 * time.Second, @@ -169,6 +169,23 @@ func (s *Service) Start(ctx context.Context, a any) { } } +func guestActionHandlers() map[string]guestactions.GuestActionHandler { + return map[string]guestactions.GuestActionHandler{ + // go/keep-sorted start + "oracle_data_guard_switchover": oraclehandlers.DataGuardSwitchover, + "oracle_disable_autostart": oraclehandlers.DisableAutostart, + "oracle_disable_restricted_mode": oraclehandlers.DisableRestrictedMode, + "oracle_enable_autostart": oraclehandlers.EnableAutostart, + "oracle_health_check": oraclehandlers.HealthCheck, + "oracle_run_datapatch": oraclehandlers.RunDatapatch, + "oracle_run_discovery": oraclehandlers.RunDiscovery, + "oracle_start_database": oraclehandlers.StartDatabase, + "oracle_start_listener": oraclehandlers.StartListener, + "oracle_stop_database": oraclehandlers.StopDatabase, + // go/keep-sorted end + } +} + func runGuestActions(ctx context.Context, a any) { log.CtxLogger(ctx).Infow("Starting guest actions listener", "channel_id", defaultChannel) args, ok := a.(runGuestActionsArgs) @@ -176,15 +193,33 @@ func runGuestActions(ctx context.Context, a any) { log.CtxLogger(ctx).Error("args is not of type runGuestActionsArgs") return } - ga := &guestactions.GuestActions{} + ga := newGuestActionsManager() + gaOpts := guestactions.Options{ - Channel: defaultChannel, - CloudProperties: convertCloudProperties(args.s.CloudProps), - Handlers: args.handlers, + Channel: defaultChannel, + CloudProperties: convertCloudProperties(args.s.CloudProps), + LROHandlers: args.handlers, + CommandConcurrencyKey: oracleCommandKey, } ga.Start(ctx, gaOpts) } +// oracleCommandKey returns the locking key and timeout for a given command and whether the command +// should be locked. +// The locking key is formed by ORACLE_SID and ORACLE_HOME. +func oracleCommandKey(cmd *gapb.Command) (key string, timeout time.Duration, lock bool) { + params := cmd.GetAgentCommand().GetParameters() + sid, sidOk := params["oracle_sid"] + home, homeOk := params["oracle_home"] + + if !sidOk || !homeOk { + // Cannot form a unique key, don't lock + return "", 0, false + } + // TODO: Add GCE Instance ID to the key to make it unique across different GCE instances. + return fmt.Sprintf("%s:%s", sid, home), defaultLockTimeout, true +} + func runDiscovery(ctx context.Context, a any) { log.CtxLogger(ctx).Info("Running Oracle Discovery") args, ok := a.(runDiscoveryArgs) diff --git a/internal/daemon/oracle/oracle_test.go b/internal/daemon/oracle/oracle_test.go index 892fc85..be57ea7 100644 --- a/internal/daemon/oracle/oracle_test.go +++ b/internal/daemon/oracle/oracle_test.go @@ -17,14 +17,30 @@ limitations under the License. package oracle import ( + "context" "testing" + "time" "github.com/google/go-cmp/cmp" "google.golang.org/protobuf/testing/protocmp" cpb "github.com/GoogleCloudPlatform/workloadagent/protos/configuration" "github.com/GoogleCloudPlatform/workloadagentplatform/sharedlibraries/gce/metadataserver" + "github.com/GoogleCloudPlatform/workloadagentplatform/sharedlibraries/guestactions" + gapb "github.com/GoogleCloudPlatform/workloadagentplatform/sharedprotos/guestactions" ) +// fakeGuestActionsManager is a test double for guestActionsManager. +type fakeGuestActionsManager struct { + startCalled bool + startOpts guestactions.Options +} + +// Start captures the options passed and marks itself as called. +func (f *fakeGuestActionsManager) Start(ctx context.Context, a any) { + f.startCalled = true + f.startOpts = a.(guestactions.Options) +} + func TestConvertCloudProperties(t *testing.T) { tests := []struct { name string @@ -74,3 +90,174 @@ func TestConvertCloudProperties(t *testing.T) { }) } } + +func TestOracleCommandKey(t *testing.T) { + tests := []struct { + name string + cmd *gapb.Command + wantKey string + wantTimeout time.Duration + wantLock bool + }{ + { + name: "nil agent command", + cmd: &gapb.Command{}, + wantKey: "", + wantTimeout: 0, + wantLock: false, + }, + { + name: "nil parameters", + cmd: &gapb.Command{ + CommandType: &gapb.Command_AgentCommand{ + AgentCommand: &gapb.AgentCommand{ + Command: "oracle_start_database", + }, + }, + }, + wantKey: "", + wantTimeout: 0, + wantLock: false, + }, + { + name: "empty parameters", + cmd: &gapb.Command{ + CommandType: &gapb.Command_AgentCommand{ + AgentCommand: &gapb.AgentCommand{ + Command: "oracle_start_database", + Parameters: map[string]string{}, + }, + }, + }, + wantKey: "", + wantTimeout: 0, + wantLock: false, + }, + { + name: "missing oracle_home", + cmd: &gapb.Command{ + CommandType: &gapb.Command_AgentCommand{ + AgentCommand: &gapb.AgentCommand{ + Command: "oracle_start_database", + Parameters: map[string]string{ + "oracle_sid": "orcl", + }, + }, + }, + }, + wantKey: "", + wantTimeout: 0, + wantLock: false, + }, + { + name: "missing oracle_sid", + cmd: &gapb.Command{ + CommandType: &gapb.Command_AgentCommand{ + AgentCommand: &gapb.AgentCommand{ + Command: "oracle_start_database", + Parameters: map[string]string{ + "oracle_home": "/u01/app/oracle/product/19.3.0/dbhome_1", + }, + }, + }, + }, + wantKey: "", + wantTimeout: 0, + wantLock: false, + }, + { + name: "with oracle_sid and oracle_home", + cmd: &gapb.Command{ + CommandType: &gapb.Command_AgentCommand{ + AgentCommand: &gapb.AgentCommand{ + Command: "oracle_health_check", + Parameters: map[string]string{ + "oracle_sid": "orcl", + "oracle_home": "/u01/app/oracle/product/19.3.0/dbhome_1", + }, + }, + }, + }, + wantKey: "orcl:/u01/app/oracle/product/19.3.0/dbhome_1", + wantTimeout: 24 * time.Hour, + wantLock: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + gotKey, gotTimeout, gotLock := oracleCommandKey(tc.cmd) + if gotKey != tc.wantKey || gotTimeout != tc.wantTimeout || gotLock != tc.wantLock { + t.Errorf("oracleCommandKey(%v) = (%q, %v, %v), want (%q, %v, %v)", tc.cmd, gotKey, gotTimeout, gotLock, tc.wantKey, tc.wantTimeout, tc.wantLock) + } + }) + } +} + +func TestRunGuestActions(t *testing.T) { + // Keep track of the original newGuestActionsManager and restore it after the test. + originalNewGuestActionsManager := newGuestActionsManager + defer func() { + newGuestActionsManager = originalNewGuestActionsManager + }() + + fakeGA := &fakeGuestActionsManager{} + newGuestActionsManager = func() guestActionsManager { + return fakeGA + } + + cloudProps := &cpb.CloudProperties{ProjectId: "test-project"} + handlers := map[string]guestactions.GuestActionHandler{ + "test_handler": func(ctx context.Context, cmd *gapb.Command, cp *metadataserver.CloudProperties) *gapb.CommandResult { + return nil + }, + } + args := runGuestActionsArgs{ + s: &Service{ + CloudProps: cloudProps, + }, + handlers: handlers, + } + + runGuestActions(context.Background(), args) + + if !fakeGA.startCalled { + t.Errorf("runGuestActions() did not call Start()") + } + + if fakeGA.startOpts.Channel != defaultChannel { + t.Errorf("runGuestActions() called Start() with channel %q, want %q", fakeGA.startOpts.Channel, defaultChannel) + } + if fakeGA.startOpts.CloudProperties.ProjectID != cloudProps.ProjectId { + t.Errorf("runGuestActions() called Start() with CloudProperties.ProjectID %q, want %q", fakeGA.startOpts.CloudProperties.ProjectID, cloudProps.ProjectId) + } + if fakeGA.startOpts.CommandConcurrencyKey == nil { + t.Errorf("runGuestActions() called Start() with nil CommandConcurrencyKey, want non-nil") + } +} + +func TestGuestActionHandlers(t *testing.T) { + handlers := guestActionHandlers() + expectedHandlers := []string{ + "oracle_data_guard_switchover", + "oracle_disable_autostart", + "oracle_disable_restricted_mode", + "oracle_enable_autostart", + "oracle_health_check", + "oracle_run_datapatch", + "oracle_run_discovery", + "oracle_start_database", + "oracle_start_listener", + "oracle_stop_database", + } + + if len(handlers) != len(expectedHandlers) { + t.Errorf("getGuestActionHandlers() returned %d handlers, want %d", len(handlers), len(expectedHandlers)) + } + + for _, h := range expectedHandlers { + if _, ok := handlers[h]; !ok { + t.Errorf("getGuestActionHandlers() missing handler for %q", h) + } + } +} diff --git a/internal/oraclehandlers/healthcheck.go b/internal/oraclehandlers/healthcheck.go index 0816a7d..56f786a 100644 --- a/internal/oraclehandlers/healthcheck.go +++ b/internal/oraclehandlers/healthcheck.go @@ -34,7 +34,7 @@ var errorRegex = regexp.MustCompile(`ORA-\d+`) const healthCheckTimeoutSeconds = 15 // HealthCheck checks if the database is healthy by executing a simple SQL query. -func (h *OracleHandler) HealthCheck(ctx context.Context, command *gpb.Command, cloudProperties *metadataserver.CloudProperties) *gpb.CommandResult { +func HealthCheck(ctx context.Context, command *gpb.Command, cloudProperties *metadataserver.CloudProperties) *gpb.CommandResult { params := command.GetAgentCommand().GetParameters() logger := log.CtxLogger(ctx) if result := validateParams(ctx, logger, command, params); result != nil { @@ -42,11 +42,6 @@ func (h *OracleHandler) HealthCheck(ctx context.Context, command *gpb.Command, c } logger = logger.With("oracle_sid", params["oracle_sid"], "oracle_home", params["oracle_home"], "oracle_user", params["oracle_user"]) logger.Debugw("oracle_health_check handler called") - unlock, result := h.lockDatabase(ctx, logger, command) - if result != nil { - return result - } - defer unlock() stdout, stderr, err := healthCheck(ctx, logger, params) if err != nil { diff --git a/internal/oraclehandlers/healthcheck_test.go b/internal/oraclehandlers/healthcheck_test.go index fbcdd7d..48a5743 100644 --- a/internal/oraclehandlers/healthcheck_test.go +++ b/internal/oraclehandlers/healthcheck_test.go @@ -19,7 +19,6 @@ package oraclehandlers import ( "context" "fmt" - "strings" "testing" "google.golang.org/protobuf/proto" @@ -84,7 +83,6 @@ func TestHealthCheck(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { origRunSQL := runSQL - h := New() defer func() { runSQL = origRunSQL }() runSQL = createMockRunSQL(tc.sqlQueries) @@ -96,7 +94,7 @@ func TestHealthCheck(t *testing.T) { }, }, } - result := h.HealthCheck(context.Background(), command, nil) + result := HealthCheck(context.Background(), command, nil) s := &spb.Status{} if err := anypb.UnmarshalTo(result.Payload, s, proto.UnmarshalOptions{}); err != nil { @@ -108,49 +106,3 @@ func TestHealthCheck(t *testing.T) { }) } } - -func TestHealthCheckLocked(t *testing.T) { - h := New() - sid := "locked_sid" - params := map[string]string{ - "oracle_sid": sid, - "oracle_home": "/u01/app/oracle/product/19.3.0/dbhome_1", - "oracle_user": "oracle", - } - command := &gpb.Command{ - CommandType: &gpb.Command_AgentCommand{ - AgentCommand: &gpb.AgentCommand{ - Command: "oracle_health_check", - Parameters: params, - }, - }, - } - - origRunSQL := runSQL - defer func() { runSQL = origRunSQL }() - - runSQLBlocked := make(chan bool) - unblockRunSQL := make(chan bool) - - runSQL = func(ctx context.Context, params map[string]string, query string, timeout int) (string, string, error) { - runSQLBlocked <- true - <-unblockRunSQL - return "1", "", nil - } - - go h.HealthCheck(context.Background(), command, nil) - <-runSQLBlocked - result := h.HealthCheck(context.Background(), command, nil) - unblockRunSQL <- true - - s := &spb.Status{} - if err := anypb.UnmarshalTo(result.Payload, s, proto.UnmarshalOptions{}); err != nil { - t.Fatalf("Failed to unmarshal payload: %v", err) - } - if s.Code != int32(codepb.Code_ABORTED) { - t.Errorf("HealthCheck() with params %v returned error code %d, want %d", params, s.Code, codepb.Code_ABORTED) - } - if !strings.Contains(s.Message, "oracle_health_check") { - t.Errorf("HealthCheck() with params %v returned %q, want %q", params, s.Message, "oracle_health_check") - } -} diff --git a/internal/oraclehandlers/oraclehandlers.go b/internal/oraclehandlers/oraclehandlers.go index 5ed1214..7531412 100644 --- a/internal/oraclehandlers/oraclehandlers.go +++ b/internal/oraclehandlers/oraclehandlers.go @@ -21,7 +21,6 @@ import ( "context" "errors" "fmt" - "sync" "go.uber.org/zap" @@ -31,62 +30,6 @@ import ( gpb "github.com/GoogleCloudPlatform/workloadagentplatform/sharedprotos/guestactions" ) -// OracleHandler holds shared state for Oracle operations like in-process locking, -// in order to prevent concurrent operations on the same database instance (SID). -type OracleHandler struct { - opMu sync.Mutex - // activeOps tracks the oracle_sids that are currently being operated on (sid -> commandName). - activeOps map[string]string -} - -// New creates a new OracleHandler instance. -func New() *OracleHandler { - return &OracleHandler{ - activeOps: make(map[string]string), - } -} - -// lockDatabase implements an in-memory lock to prevent concurrent agent operations -// (e.g., oracle_start_database, oracle_stop_database) on the same Oracle SID. -// It uses a per-SID locking mechanism to allow concurrent operations on different -// database SIDs, but preventing multiple operations on the same SID simultaneously. -// If an operation is already in progress for the SID specified in the command, -// lockDatabase returns an ABORTED result immediately. -// If no operation is in progress, it registers the current operation as active -// and returns: -// 1. A cleanup function to release the lock, which should be deferred by the caller. -// 2. A nil CommandResult. -func (h *OracleHandler) lockDatabase(ctx context.Context, logger *zap.SugaredLogger, command *gpb.Command) (func(), *gpb.CommandResult) { - params := command.GetAgentCommand().GetParameters() - commandName := command.GetAgentCommand().GetCommand() - - h.opMu.Lock() - if op, ok := h.activeOps[params["oracle_sid"]]; ok { - h.opMu.Unlock() - logger.Errorw("Operation already in progress") - anyPayload, err := anypb.New(&spb.Status{ - Code: int32(codepb.Code_ABORTED), - Message: fmt.Sprintf("operation %q already in progress for database %s", op, params["oracle_sid"]), - }) - if err != nil { - logger.Warnw("Failed to pack payload", "error", err) - } - return func() {}, &gpb.CommandResult{ - Command: command, - ExitCode: 1, // Signal guestactions framework that the command failed. - Payload: anyPayload, - } - } - h.activeOps[params["oracle_sid"]] = commandName - h.opMu.Unlock() - - return func() { - h.opMu.Lock() - delete(h.activeOps, params["oracle_sid"]) - h.opMu.Unlock() - }, nil -} - // commandResult creates a gpb.CommandResult with the given status code and message packed into the payload. func commandResult(ctx context.Context, logger *zap.SugaredLogger, command *gpb.Command, stdout, stderr string, code codepb.Code, message string, execErr error) *gpb.CommandResult { anyPayload, err := anypb.New(&spb.Status{ diff --git a/internal/oraclehandlers/oraclehandlers_test.go b/internal/oraclehandlers/oraclehandlers_test.go index 917f7d3..e87e484 100644 --- a/internal/oraclehandlers/oraclehandlers_test.go +++ b/internal/oraclehandlers/oraclehandlers_test.go @@ -113,3 +113,89 @@ func TestCommandResult(t *testing.T) { }) } } + +func TestValidateParams(t *testing.T) { + tests := []struct { + name string + params map[string]string + wantErrorCode codepb.Code + }{ + { + name: "nil params", + params: nil, + wantErrorCode: codepb.Code_INVALID_ARGUMENT, + }, + { + name: "empty params", + params: map[string]string{}, + wantErrorCode: codepb.Code_INVALID_ARGUMENT, + }, + { + name: "missing oracle_sid", + params: map[string]string{ + "oracle_home": "home", + "oracle_user": "user", + }, + wantErrorCode: codepb.Code_INVALID_ARGUMENT, + }, + { + name: "empty oracle_sid", + params: map[string]string{ + "oracle_sid": "", + "oracle_home": "home", + "oracle_user": "user", + }, + wantErrorCode: codepb.Code_INVALID_ARGUMENT, + }, + { + name: "missing oracle_home", + params: map[string]string{ + "oracle_sid": "sid", + "oracle_user": "user", + }, + wantErrorCode: codepb.Code_INVALID_ARGUMENT, + }, + { + name: "missing oracle_user", + params: map[string]string{ + "oracle_sid": "sid", + "oracle_home": "home", + }, + wantErrorCode: codepb.Code_INVALID_ARGUMENT, + }, + { + name: "valid params", + params: map[string]string{ + "oracle_sid": "sid", + "oracle_home": "home", + "oracle_user": "user", + }, + wantErrorCode: codepb.Code_OK, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := validateParams(context.Background(), zap.NewNop().Sugar(), nil, tc.params) + + if tc.wantErrorCode == codepb.Code_OK { + if got != nil { + t.Errorf("validateParams() returned %v, want nil", got) + } + return + } + + if got == nil { + t.Fatalf("validateParams() returned nil, want error") + } + + s := &spb.Status{} + if err := anypb.UnmarshalTo(got.Payload, s, proto.UnmarshalOptions{}); err != nil { + t.Fatalf("Failed to unmarshal payload: %v", err) + } + if s.Code != int32(tc.wantErrorCode) { + t.Errorf("validateParams() returned error code %d, want %d", s.Code, tc.wantErrorCode) + } + }) + } +} diff --git a/internal/oraclehandlers/patching.go b/internal/oraclehandlers/patching.go index bc73a93..f843fe8 100644 --- a/internal/oraclehandlers/patching.go +++ b/internal/oraclehandlers/patching.go @@ -25,7 +25,7 @@ import ( ) // DisableAutostart implements the oracle_disable_autostart guest action. -func (h *OracleHandler) DisableAutostart(ctx context.Context, command *gpb.Command, cloudProperties *metadataserver.CloudProperties) *gpb.CommandResult { +func DisableAutostart(ctx context.Context, command *gpb.Command, cloudProperties *metadataserver.CloudProperties) *gpb.CommandResult { log.CtxLogger(ctx).Info("oracle_disable_autostart handler called") // TODO: Implement oracle_disable_autostart handler. return &gpb.CommandResult{ @@ -36,7 +36,7 @@ func (h *OracleHandler) DisableAutostart(ctx context.Context, command *gpb.Comma } // RunDatapatch implements the oracle_run_datapatch guest action. -func (h *OracleHandler) RunDatapatch(ctx context.Context, command *gpb.Command, cloudProperties *metadataserver.CloudProperties) *gpb.CommandResult { +func RunDatapatch(ctx context.Context, command *gpb.Command, cloudProperties *metadataserver.CloudProperties) *gpb.CommandResult { log.CtxLogger(ctx).Info("oracle_run_datapatch handler called") // TODO: Implement oracle_run_datapatch handler. return &gpb.CommandResult{ @@ -47,7 +47,7 @@ func (h *OracleHandler) RunDatapatch(ctx context.Context, command *gpb.Command, } // DisableRestrictedMode implements the oracle_disable_restricted_mode guest action. -func (h *OracleHandler) DisableRestrictedMode(ctx context.Context, command *gpb.Command, cloudProperties *metadataserver.CloudProperties) *gpb.CommandResult { +func DisableRestrictedMode(ctx context.Context, command *gpb.Command, cloudProperties *metadataserver.CloudProperties) *gpb.CommandResult { log.CtxLogger(ctx).Info("oracle_disable_restricted_mode handler called") // TODO: Implement oracle_disable_restricted_mode handler. return &gpb.CommandResult{ @@ -58,7 +58,7 @@ func (h *OracleHandler) DisableRestrictedMode(ctx context.Context, command *gpb. } // StartListener implements the oracle_start_listener guest action. -func (h *OracleHandler) StartListener(ctx context.Context, command *gpb.Command, cloudProperties *metadataserver.CloudProperties) *gpb.CommandResult { +func StartListener(ctx context.Context, command *gpb.Command, cloudProperties *metadataserver.CloudProperties) *gpb.CommandResult { log.CtxLogger(ctx).Info("oracle_start_listener handler called") // TODO: Implement oracle_start_listener handler. return &gpb.CommandResult{ @@ -69,7 +69,7 @@ func (h *OracleHandler) StartListener(ctx context.Context, command *gpb.Command, } // EnableAutostart implements the oracle_enable_autostart guest action. -func (h *OracleHandler) EnableAutostart(ctx context.Context, command *gpb.Command, cloudProperties *metadataserver.CloudProperties) *gpb.CommandResult { +func EnableAutostart(ctx context.Context, command *gpb.Command, cloudProperties *metadataserver.CloudProperties) *gpb.CommandResult { log.CtxLogger(ctx).Info("oracle_enable_autostart handler called") // TODO: Implement oracle_enable_autostart handler. return &gpb.CommandResult{ diff --git a/internal/oraclehandlers/patching_test.go b/internal/oraclehandlers/patching_test.go new file mode 100644 index 0000000..7e1e031 --- /dev/null +++ b/internal/oraclehandlers/patching_test.go @@ -0,0 +1,110 @@ +/* +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package oraclehandlers + +import ( + "context" + "strings" + "testing" + + gpb "github.com/GoogleCloudPlatform/workloadagentplatform/sharedprotos/guestactions" +) + +func TestDisableAutostart_NotImplemented(t *testing.T) { + command := &gpb.Command{ + CommandType: &gpb.Command_AgentCommand{ + AgentCommand: &gpb.AgentCommand{ + Command: "oracle_disable_autostart", + }, + }, + } + result := DisableAutostart(context.Background(), command, nil) + if result.GetExitCode() != 1 { + t.Errorf("DisableAutostart() returned exit code %d, want 1", result.GetExitCode()) + } + if !strings.Contains(result.GetStdout(), "not implemented") { + t.Errorf("DisableAutostart() returned stdout %q, want 'not implemented'", result.GetStdout()) + } +} + +func TestRunDatapatch_NotImplemented(t *testing.T) { + command := &gpb.Command{ + CommandType: &gpb.Command_AgentCommand{ + AgentCommand: &gpb.AgentCommand{ + Command: "oracle_run_datapatch", + }, + }, + } + result := RunDatapatch(context.Background(), command, nil) + if result.GetExitCode() != 1 { + t.Errorf("RunDatapatch() returned exit code %d, want 1", result.GetExitCode()) + } + if !strings.Contains(result.GetStdout(), "not implemented") { + t.Errorf("RunDatapatch() returned stdout %q, want 'not implemented'", result.GetStdout()) + } +} + +func TestDisableRestrictedMode_NotImplemented(t *testing.T) { + command := &gpb.Command{ + CommandType: &gpb.Command_AgentCommand{ + AgentCommand: &gpb.AgentCommand{ + Command: "oracle_disable_restricted_mode", + }, + }, + } + result := DisableRestrictedMode(context.Background(), command, nil) + if result.GetExitCode() != 1 { + t.Errorf("DisableRestrictedMode() returned exit code %d, want 1", result.GetExitCode()) + } + if !strings.Contains(result.GetStdout(), "not implemented") { + t.Errorf("DisableRestrictedMode() returned stdout %q, want 'not implemented'", result.GetStdout()) + } +} + +func TestStartListener_NotImplemented(t *testing.T) { + command := &gpb.Command{ + CommandType: &gpb.Command_AgentCommand{ + AgentCommand: &gpb.AgentCommand{ + Command: "oracle_start_listener", + }, + }, + } + result := StartListener(context.Background(), command, nil) + if result.GetExitCode() != 1 { + t.Errorf("StartListener() returned exit code %d, want 1", result.GetExitCode()) + } + if !strings.Contains(result.GetStdout(), "not implemented") { + t.Errorf("StartListener() returned stdout %q, want 'not implemented'", result.GetStdout()) + } +} + +func TestEnableAutostart_NotImplemented(t *testing.T) { + command := &gpb.Command{ + CommandType: &gpb.Command_AgentCommand{ + AgentCommand: &gpb.AgentCommand{ + Command: "oracle_enable_autostart", + }, + }, + } + result := EnableAutostart(context.Background(), command, nil) + if result.GetExitCode() != 1 { + t.Errorf("EnableAutostart() returned exit code %d, want 1", result.GetExitCode()) + } + if !strings.Contains(result.GetStdout(), "not implemented") { + t.Errorf("EnableAutostart() returned stdout %q, want 'not implemented'", result.GetStdout()) + } +} diff --git a/internal/oraclehandlers/rundiscovery.go b/internal/oraclehandlers/rundiscovery.go index f65131d..cf8101d 100644 --- a/internal/oraclehandlers/rundiscovery.go +++ b/internal/oraclehandlers/rundiscovery.go @@ -25,7 +25,7 @@ import ( ) // RunDiscovery implements the oracle_run_discovery guest action. -func (h *OracleHandler) RunDiscovery(ctx context.Context, command *gpb.Command, cloudProperties *metadataserver.CloudProperties) *gpb.CommandResult { +func RunDiscovery(ctx context.Context, command *gpb.Command, cloudProperties *metadataserver.CloudProperties) *gpb.CommandResult { log.CtxLogger(ctx).Info("oracle_run_discovery handler called") // TODO: Implement oracle_run_discovery handler. return &gpb.CommandResult{ diff --git a/internal/oraclehandlers/rundiscovery_test.go b/internal/oraclehandlers/rundiscovery_test.go new file mode 100644 index 0000000..45f9486 --- /dev/null +++ b/internal/oraclehandlers/rundiscovery_test.go @@ -0,0 +1,42 @@ +/* +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package oraclehandlers + +import ( + "context" + "strings" + "testing" + + gpb "github.com/GoogleCloudPlatform/workloadagentplatform/sharedprotos/guestactions" +) + +func TestRunDiscovery_NotImplemented(t *testing.T) { + command := &gpb.Command{ + CommandType: &gpb.Command_AgentCommand{ + AgentCommand: &gpb.AgentCommand{ + Command: "oracle_run_discovery", + }, + }, + } + result := RunDiscovery(context.Background(), command, nil) + if result.GetExitCode() != 1 { + t.Errorf("RunDiscovery() returned exit code %d, want 1", result.GetExitCode()) + } + if !strings.Contains(result.GetStdout(), "not implemented") { + t.Errorf("RunDiscovery() returned stdout %q, want 'not implemented'", result.GetStdout()) + } +} diff --git a/internal/oraclehandlers/startstop.go b/internal/oraclehandlers/startstop.go index 67f8f9f..303f690 100644 --- a/internal/oraclehandlers/startstop.go +++ b/internal/oraclehandlers/startstop.go @@ -45,7 +45,7 @@ const ( // StopDatabase implements the oracle_stop_database guest action. // It attempts to shut down the database gracefully using "SHUTDOWN IMMEDIATE". // If the command fails, it returns an error in the payload. -func (h *OracleHandler) StopDatabase(ctx context.Context, command *gpb.Command, cloudProperties *metadataserver.CloudProperties) *gpb.CommandResult { +func StopDatabase(ctx context.Context, command *gpb.Command, cloudProperties *metadataserver.CloudProperties) *gpb.CommandResult { params := command.GetAgentCommand().GetParameters() logger := log.CtxLogger(ctx) if result := validateParams(ctx, logger, command, params); result != nil { @@ -54,12 +54,6 @@ func (h *OracleHandler) StopDatabase(ctx context.Context, command *gpb.Command, logger = logger.With("oracle_sid", params["oracle_sid"], "oracle_home", params["oracle_home"], "oracle_user", params["oracle_user"]) logger.Infow("oracle_stop_database handler called") - unlock, result := h.lockDatabase(ctx, logger, command) - if result != nil { - return result - } - defer unlock() - // TODO: Handle Data Guard standby databases. stdout, stderr, err := stopDatabase(ctx, logger, params) if err != nil { @@ -90,7 +84,7 @@ func stopDatabase(ctx context.Context, logger *zap.SugaredLogger, params map[str // StartDatabase implements the oracle_start_database guest action. // It checks the current status of the database and starts it if it is not already running. // If the database is already mounted, it attempts to open it. -func (h *OracleHandler) StartDatabase(ctx context.Context, command *gpb.Command, cloudProperties *metadataserver.CloudProperties) *gpb.CommandResult { +func StartDatabase(ctx context.Context, command *gpb.Command, cloudProperties *metadataserver.CloudProperties) *gpb.CommandResult { params := command.GetAgentCommand().GetParameters() logger := log.CtxLogger(ctx) if result := validateParams(ctx, logger, command, params); result != nil { @@ -99,12 +93,6 @@ func (h *OracleHandler) StartDatabase(ctx context.Context, command *gpb.Command, logger = logger.With("oracle_sid", params["oracle_sid"], "oracle_home", params["oracle_home"], "oracle_user", params["oracle_user"]) logger.Infow("oracle_start_database handler called") - unlock, result := h.lockDatabase(ctx, logger, command) - if result != nil { - return result - } - defer unlock() - // TODO: Handle Data Guard standby databases. // Cases to consider: // 1. Active Data Guard: Open read-only. diff --git a/internal/oraclehandlers/startstop_test.go b/internal/oraclehandlers/startstop_test.go index 0df8137..30f4b62 100644 --- a/internal/oraclehandlers/startstop_test.go +++ b/internal/oraclehandlers/startstop_test.go @@ -19,7 +19,6 @@ package oraclehandlers import ( "context" "fmt" - "strings" "testing" "google.golang.org/protobuf/proto" @@ -92,12 +91,23 @@ func TestStopDatabase(t *testing.T) { }, wantErrorCode: codepb.Code_OK, }, + { + name: "shutdown immediate unexpected output", + params: map[string]string{ + "oracle_sid": "orcl", + "oracle_home": "/u01/app/oracle/product/19.3.0/dbhome_1", + "oracle_user": "oracle", + }, + sqlQueries: map[string]*commandlineexecutor.Result{ + "SHUTDOWN IMMEDIATE": &commandlineexecutor.Result{StdOut: "Some unexpected output"}, + }, + wantErrorCode: codepb.Code_FAILED_PRECONDITION, + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { origRunSQL := runSQL - h := New() defer func() { runSQL = origRunSQL }() runSQL = createMockRunSQL(tc.sqlQueries) @@ -109,7 +119,7 @@ func TestStopDatabase(t *testing.T) { }, }, } - result := h.StopDatabase(context.Background(), command, nil) + result := StopDatabase(context.Background(), command, nil) s := &spb.Status{} if err := anypb.UnmarshalTo(result.Payload, s, proto.UnmarshalOptions{}); err != nil { @@ -122,59 +132,6 @@ func TestStopDatabase(t *testing.T) { } } -func TestStopDatabaseLocked(t *testing.T) { - h := New() - sid := "locked_sid" - params := map[string]string{ - "oracle_sid": sid, - "oracle_home": "/u01/app/oracle/product/19.3.0/dbhome_1", - "oracle_user": "oracle", - } - command := &gpb.Command{ - CommandType: &gpb.Command_AgentCommand{ - AgentCommand: &gpb.AgentCommand{ - Command: "oracle_stop_database", - Parameters: params, - }, - }, - } - - origRunSQL := runSQL - defer func() { runSQL = origRunSQL }() - - runSQLBlocked := make(chan bool) - unblockRunSQL := make(chan bool) - - runSQL = func(ctx context.Context, params map[string]string, query string, timeout int) (string, string, error) { - runSQLBlocked <- true - <-unblockRunSQL - return shutdownSuccess, "", nil - } - - // Start the first operation in a goroutine. - go h.StopDatabase(context.Background(), command, nil) - - // Wait until the first operation has locked the DB and is blocked in runSQL. - <-runSQLBlocked - - // Attempt to start a second operation on the same SID. - result := h.StopDatabase(context.Background(), command, nil) - - // Unblock the first operation. - unblockRunSQL <- true - - s := &spb.Status{} - if err := anypb.UnmarshalTo(result.Payload, s, proto.UnmarshalOptions{}); err != nil { - t.Fatalf("Failed to unmarshal payload: %v", err) - } - if s.Code != int32(codepb.Code_ABORTED) { - t.Errorf("StopDatabase() returned error code %d, want %d", s.Code, codepb.Code_ABORTED) - } - if !strings.Contains(s.Message, `"oracle_stop_database"`) { - t.Errorf("StopDatabase() returned message %q, want it to contain %q", s.Message, `"oracle_stop_database"`) - } -} - func TestStartDatabase(t *testing.T) { tests := []struct { name string @@ -265,12 +222,36 @@ func TestStartDatabase(t *testing.T) { }, wantErrorCode: codepb.Code_FAILED_PRECONDITION, }, + { + name: "startup unexpected output", + params: map[string]string{ + "oracle_sid": "orcl", + "oracle_home": "/u01/app/oracle/product/19.3.0/dbhome_1", + "oracle_user": "oracle", + }, + sqlQueries: map[string]*commandlineexecutor.Result{ + "STARTUP": &commandlineexecutor.Result{StdOut: "Some unexpected output"}, + }, + wantErrorCode: codepb.Code_FAILED_PRECONDITION, + }, + { + name: "startup already running status check fail", + params: map[string]string{ + "oracle_sid": "orcl", + "oracle_home": "/u01/app/oracle/product/19.3.0/dbhome_1", + "oracle_user": "oracle", + }, + sqlQueries: map[string]*commandlineexecutor.Result{ + "STARTUP": &commandlineexecutor.Result{StdOut: alreadyRunning}, + "SELECT status FROM v$instance;": &commandlineexecutor.Result{ExitCode: 1, Error: fmt.Errorf("status check failed")}, + }, + wantErrorCode: codepb.Code_FAILED_PRECONDITION, + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { origRunSQL := runSQL - h := New() defer func() { runSQL = origRunSQL }() runSQL = createMockRunSQL(tc.sqlQueries) @@ -282,7 +263,7 @@ func TestStartDatabase(t *testing.T) { }, }, } - result := h.StartDatabase(context.Background(), command, nil) + result := StartDatabase(context.Background(), command, nil) s := &spb.Status{} if err := anypb.UnmarshalTo(result.Payload, s, proto.UnmarshalOptions{}); err != nil { @@ -294,56 +275,3 @@ func TestStartDatabase(t *testing.T) { }) } } - -func TestStartDatabaseLocked(t *testing.T) { - h := New() - sid := "locked_sid" - params := map[string]string{ - "oracle_sid": sid, - "oracle_home": "/u01/app/oracle/product/19.3.0/dbhome_1", - "oracle_user": "oracle", - } - command := &gpb.Command{ - CommandType: &gpb.Command_AgentCommand{ - AgentCommand: &gpb.AgentCommand{ - Command: "oracle_start_database", - Parameters: params, - }, - }, - } - - origRunSQL := runSQL - defer func() { runSQL = origRunSQL }() - - runSQLBlocked := make(chan bool) - unblockRunSQL := make(chan bool) - - runSQL = func(ctx context.Context, params map[string]string, query string, timeout int) (string, string, error) { - runSQLBlocked <- true // Signal that runSQL has been reached - <-unblockRunSQL // Pause here until signaled to continue - return startupSuccess, "", nil - } - - // Start the first operation in a goroutine. - go h.StartDatabase(context.Background(), command, nil) - - // Wait until the first operation has locked the DB and is blocked in runSQL. - <-runSQLBlocked - - // Attempt to start a second operation on the same SID. - result := h.StartDatabase(context.Background(), command, nil) - - // Unblock the first operation. - unblockRunSQL <- true - - s := &spb.Status{} - if err := anypb.UnmarshalTo(result.Payload, s, proto.UnmarshalOptions{}); err != nil { - t.Fatalf("Failed to unmarshal payload: %v", err) - } - if s.Code != int32(codepb.Code_ABORTED) { - t.Errorf("StartDatabase() returned error code %d, want %d", s.Code, codepb.Code_ABORTED) - } - if !strings.Contains(s.Message, `"oracle_start_database"`) { - t.Errorf("StartDatabase() returned message %q, want it to contain %q", s.Message, `"oracle_start_database"`) - } -} diff --git a/internal/oraclehandlers/switchover.go b/internal/oraclehandlers/switchover.go index 76b6b05..a6487e7 100644 --- a/internal/oraclehandlers/switchover.go +++ b/internal/oraclehandlers/switchover.go @@ -25,7 +25,7 @@ import ( ) // DataGuardSwitchover implements the oracle_data_guard_switchover guest action. -func (h *OracleHandler) DataGuardSwitchover(ctx context.Context, command *gpb.Command, cloudProperties *metadataserver.CloudProperties) *gpb.CommandResult { +func DataGuardSwitchover(ctx context.Context, command *gpb.Command, cloudProperties *metadataserver.CloudProperties) *gpb.CommandResult { log.CtxLogger(ctx).Info("oracle_data_guard_switchover handler called") // TODO: Implement oracle_data_guard_switchover handler. return &gpb.CommandResult{ diff --git a/internal/oraclehandlers/switchover_test.go b/internal/oraclehandlers/switchover_test.go new file mode 100644 index 0000000..11131ce --- /dev/null +++ b/internal/oraclehandlers/switchover_test.go @@ -0,0 +1,42 @@ +/* +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package oraclehandlers + +import ( + "context" + "strings" + "testing" + + gpb "github.com/GoogleCloudPlatform/workloadagentplatform/sharedprotos/guestactions" +) + +func TestDataGuardSwitchover_NotImplemented(t *testing.T) { + command := &gpb.Command{ + CommandType: &gpb.Command_AgentCommand{ + AgentCommand: &gpb.AgentCommand{ + Command: "oracle_data_guard_switchover", + }, + }, + } + result := DataGuardSwitchover(context.Background(), command, nil) + if result.GetExitCode() != 1 { + t.Errorf("DataGuardSwitchover() returned exit code %d, want 1", result.GetExitCode()) + } + if !strings.Contains(result.GetStdout(), "not implemented") { + t.Errorf("DataGuardSwitchover() returned stdout %q, want 'not implemented'", result.GetStdout()) + } +}