From cd484d121ed8594b4793341a06b21bfe99932a98 Mon Sep 17 00:00:00 2001 From: RebeccaMahany Date: Mon, 30 Dec 2024 13:46:39 -0500 Subject: [PATCH] Move errgroup to its own package --- ee/errgroup/errgroup.go | 156 +++++++++++++++++++++++++ pkg/osquery/runtime/osqueryinstance.go | 138 +++++----------------- 2 files changed, 188 insertions(+), 106 deletions(-) create mode 100644 ee/errgroup/errgroup.go diff --git a/ee/errgroup/errgroup.go b/ee/errgroup/errgroup.go new file mode 100644 index 000000000..611fa9a5f --- /dev/null +++ b/ee/errgroup/errgroup.go @@ -0,0 +1,156 @@ +package errgroup + +import ( + "context" + "log/slog" + "time" + + "golang.org/x/sync/errgroup" +) + +type LoggedErrgroup struct { + errgroup *errgroup.Group + cancel context.CancelFunc + doneCtx context.Context // nolint:containedctx + slogger *slog.Logger +} + +const ( + maxShutdownGoroutineDuration = 3 * time.Second +) + +func NewLoggedErrgroup(ctx context.Context, slogger *slog.Logger) *LoggedErrgroup { + ctx, cancel := context.WithCancel(ctx) + e, doneCtx := errgroup.WithContext(ctx) + + return &LoggedErrgroup{ + errgroup: e, + cancel: cancel, + doneCtx: doneCtx, + slogger: slogger, + } +} + +// AddGoroutineToErrgroup adds the given goroutine to the errgroup, ensuring that we log its start and exit. +func (l *LoggedErrgroup) AddGoroutineToErrgroup(ctx context.Context, goroutineName string, goroutine func() error) { + l.errgroup.Go(func() error { + l.slogger.Log(ctx, slog.LevelInfo, + "starting goroutine in errgroup", + "goroutine_name", goroutineName, + ) + + goroutineStart := time.Now() + err := goroutine() + elapsedTime := time.Since(goroutineStart) + + l.slogger.Log(ctx, slog.LevelInfo, + "exiting goroutine in errgroup", + "goroutine_name", goroutineName, + "goroutine_run_time", elapsedTime.String(), + "goroutine_err", err, + ) + + return err + }) +} + +// AddRepeatedGoroutineToErrgroup adds the given goroutine to the errgroup, ensuring that we log its start and exit. +// If the delay is non-zero, the goroutine will not start until after the delay interval has elapsed. The goroutine +// will run on the given interval, and will continue to run until it returns an error or the errgroup shuts down. +func (l *LoggedErrgroup) AddRepeatedGoroutineToErrgroup(ctx context.Context, goroutineName string, interval time.Duration, delay time.Duration, goroutine func() error) { + l.errgroup.Go(func() error { + l.slogger.Log(ctx, slog.LevelInfo, + "starting repeated goroutine in errgroup", + "goroutine_name", goroutineName, + "goroutine_interval", interval.String(), + "goroutine_start_delay", delay.String(), + ) + + if delay != 0*time.Second { + select { + case <-time.After(delay): + l.slogger.Log(ctx, slog.LevelDebug, + "exiting delay before starting repeated goroutine", + "goroutine_name", goroutineName, + ) + case <-l.doneCtx.Done(): + return nil + } + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-l.doneCtx.Done(): + l.slogger.Log(ctx, slog.LevelInfo, + "exiting repeated goroutine in errgroup", + "goroutine_name", goroutineName, + ) + return nil + case <-ticker.C: + goroutineStart := time.Now() + err := goroutine() + elapsedTime := time.Since(goroutineStart) + + if err != nil { + l.slogger.Log(ctx, slog.LevelInfo, + "exiting repeated goroutine in errgroup", + "goroutine_name", goroutineName, + "goroutine_run_time", elapsedTime.String(), + "goroutine_err", err, + ) + return err + } + } + } + }) +} + +// AddShutdownGoroutineToErrgroup adds the given goroutine to the errgroup, ensuring that we log its start and exit. +// The goroutine will not execute until the errgroup has received a signal to exit. +func (l *LoggedErrgroup) AddShutdownGoroutineToErrgroup(ctx context.Context, goroutineName string, goroutine func() error) { + l.errgroup.Go(func() error { + // Wait for errgroup to exit + <-l.doneCtx.Done() + + l.slogger.Log(ctx, slog.LevelInfo, + "starting shutdown goroutine in errgroup", + "goroutine_name", goroutineName, + ) + + goroutineStart := time.Now() + err := goroutine() + elapsedTime := time.Since(goroutineStart) + + logLevel := slog.LevelInfo + if elapsedTime > maxShutdownGoroutineDuration { + logLevel = slog.LevelWarn + } + + l.slogger.Log(ctx, logLevel, + "exiting shutdown goroutine in errgroup", + "goroutine_name", goroutineName, + "goroutine_run_time", elapsedTime.String(), + "goroutine_err", err, + ) + + // We don't want to actually return the error here, to avoid causing an otherwise successful call + // to `Shutdown` => `Wait` to return an error. Shutdown routine errors don't matter for the success + // of the errgroup overall. + return l.doneCtx.Err() + }) +} + +func (l *LoggedErrgroup) Shutdown() { + l.cancel() +} + +func (l *LoggedErrgroup) Wait() error { + return l.errgroup.Wait() +} + +func (l *LoggedErrgroup) Exited() <-chan struct{} { + return l.doneCtx.Done() +} diff --git a/pkg/osquery/runtime/osqueryinstance.go b/pkg/osquery/runtime/osqueryinstance.go index 0ecb3ad4d..1173ed201 100644 --- a/pkg/osquery/runtime/osqueryinstance.go +++ b/pkg/osquery/runtime/osqueryinstance.go @@ -17,6 +17,7 @@ import ( "github.com/kolide/kit/ulid" "github.com/kolide/launcher/ee/agent/types" + "github.com/kolide/launcher/ee/errgroup" "github.com/kolide/launcher/ee/gowrapper" kolidelog "github.com/kolide/launcher/ee/log/osquerylogs" "github.com/kolide/launcher/pkg/backoff" @@ -31,8 +32,6 @@ import ( osquerylogger "github.com/osquery/osquery-go/plugin/logger" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" - - "golang.org/x/sync/errgroup" ) const ( @@ -91,10 +90,8 @@ type OsqueryInstance struct { // the following are instance artifacts that are created and held as a result // of launching an osqueryd process runId string // string identifier for this instance - errgroup *errgroup.Group + errgroup *errgroup.LoggedErrgroup saasExtension *launcherosq.Extension - doneCtx context.Context // nolint:containedctx - cancel context.CancelFunc cmd *exec.Cmd emsLock sync.RWMutex // Lock for extensionManagerServers extensionManagerServers []*osquery.ExtensionManagerServer @@ -184,9 +181,7 @@ func newInstance(registrationId string, knapsack types.Knapsack, serviceClient s opt(i) } - ctx, cancel := context.WithCancel(context.Background()) - i.cancel = cancel - i.errgroup, i.doneCtx = errgroup.WithContext(ctx) + i.errgroup = errgroup.NewLoggedErrgroup(context.Background(), i.slogger) i.startFunc = func(cmd *exec.Cmd) error { return cmd.Start() @@ -200,7 +195,7 @@ func (i *OsqueryInstance) BeginShutdown() { i.slogger.Log(context.TODO(), slog.LevelInfo, "instance shutdown requested", ) - i.cancel() + i.errgroup.Shutdown() } // WaitShutdown waits for the instance's errgroup routines to exit, then returns the @@ -226,7 +221,7 @@ func (i *OsqueryInstance) WaitShutdown() error { // Exited returns a channel to monitor for signal that instance has shut itself down func (i *OsqueryInstance) Exited() <-chan struct{} { - return i.doneCtx.Done() + return i.errgroup.Exited() } // Launch starts the osquery instance and its components. It will run until one of its @@ -355,7 +350,7 @@ func (i *OsqueryInstance) Launch() error { // This loop runs in the background when the process was // successfully started. ("successful" is independent of exit // code. eg: this runs if we could exec. Failure to exec is above.) - i.addGoroutineToErrgroup(ctx, "monitor_osquery_process", func() error { + i.errgroup.AddGoroutineToErrgroup(ctx, "monitor_osquery_process", func() error { err := i.cmd.Wait() switch { case err == nil, isExitOk(err): @@ -378,7 +373,7 @@ func (i *OsqueryInstance) Launch() error { }) // Kill osquery process on shutdown - i.addShutdownGoroutineToErrgroup(ctx, "kill_osquery_process", func() error { + i.errgroup.AddShutdownGoroutineToErrgroup(ctx, "kill_osquery_process", func() error { if i.cmd.Process != nil { // kill osqueryd and children if err := killProcessGroup(i.cmd); err != nil { @@ -387,14 +382,11 @@ func (i *OsqueryInstance) Launch() error { "tried to stop osquery, but process already gone", ) } else { - i.slogger.Log(ctx, slog.LevelWarn, - "error killing osquery process", - "err", err, - ) + return fmt.Errorf("killing osquery process: %w", err) } } } - return i.doneCtx.Err() + return nil }) // Start an extension manager for the extensions that osquery @@ -434,43 +426,22 @@ func (i *OsqueryInstance) Launch() error { } // Health check on interval - i.addGoroutineToErrgroup(ctx, "healthcheck", func() error { - if i.knapsack.OsqueryHealthcheckStartupDelay() != 0*time.Second { - i.slogger.Log(ctx, slog.LevelDebug, - "entering delay before starting osquery healthchecks", - ) - select { - case <-time.After(i.knapsack.OsqueryHealthcheckStartupDelay()): - i.slogger.Log(ctx, slog.LevelDebug, - "exiting delay before starting osquery healthchecks", - ) - case <-i.doneCtx.Done(): - return i.doneCtx.Err() - } + i.errgroup.AddRepeatedGoroutineToErrgroup(ctx, "healthcheck", healthCheckInterval, i.knapsack.OsqueryHealthcheckStartupDelay(), func() error { + // If device is sleeping, we do not want to perform unnecessary healthchecks that + // may force an unnecessary restart. + if i.knapsack != nil && i.knapsack.InModernStandby() { + return nil } - ticker := time.NewTicker(healthCheckInterval) - defer ticker.Stop() - for { - select { - case <-i.doneCtx.Done(): - return i.doneCtx.Err() - case <-ticker.C: - // If device is sleeping, we do not want to perform unnecessary healthchecks that - // may force an unnecessary restart. - if i.knapsack != nil && i.knapsack.InModernStandby() { - break - } - - if err := i.healthcheckWithRetries(ctx, 5, 1*time.Second); err != nil { - return fmt.Errorf("health check failed: %w", err) - } - } + if err := i.healthcheckWithRetries(ctx, 5, 1*time.Second); err != nil { + return fmt.Errorf("health check failed: %w", err) } + + return nil }) // Clean up PID file on shutdown - i.addShutdownGoroutineToErrgroup(ctx, "remove_pid_file", func() error { + i.errgroup.AddShutdownGoroutineToErrgroup(ctx, "remove_pid_file", func() error { // We do a couple retries -- on Windows, the PID file may still be in use // and therefore unable to be removed. if err := backoff.WaitFor(func() error { @@ -479,17 +450,13 @@ func (i *OsqueryInstance) Launch() error { } return nil }, 5*time.Second, 500*time.Millisecond); err != nil { - i.slogger.Log(ctx, slog.LevelInfo, - "could not remove PID file, despite retries", - "pid_file", paths.pidfilePath, - "err", err, - ) + return fmt.Errorf("removing PID file %s failed with retries: %w", paths.pidfilePath, err) } - return i.doneCtx.Err() + return nil }) // Clean up socket file on shutdown - i.addShutdownGoroutineToErrgroup(ctx, "remove_socket_file", func() error { + i.errgroup.AddShutdownGoroutineToErrgroup(ctx, "remove_socket_file", func() error { // We do a couple retries -- on Windows, the socket file may still be in use // and therefore unable to be removed. if err := backoff.WaitFor(func() error { @@ -498,13 +465,9 @@ func (i *OsqueryInstance) Launch() error { } return nil }, 5*time.Second, 500*time.Millisecond); err != nil { - i.slogger.Log(ctx, slog.LevelInfo, - "could not remove socket file, despite retries", - "socket_file", paths.extensionSocketPath, - "err", err, - ) + return fmt.Errorf("removing socket file %s failed with retries: %w", paths.extensionSocketPath, err) } - return i.doneCtx.Err() + return nil }) return nil @@ -599,7 +562,7 @@ func (i *OsqueryInstance) startKolideSaasExtension(ctx context.Context) error { }, func(r any) {}) // Run extension - i.addGoroutineToErrgroup(ctx, "saas_extension_execute", func() error { + i.errgroup.AddGoroutineToErrgroup(ctx, "saas_extension_execute", func() error { if err := i.saasExtension.Execute(); err != nil { return fmt.Errorf("kolide_grpc extension returned error: %w", err) } @@ -607,52 +570,14 @@ func (i *OsqueryInstance) startKolideSaasExtension(ctx context.Context) error { }) // Register shutdown group for extension - i.addShutdownGoroutineToErrgroup(ctx, "saas_extension_cleanup", func() error { - i.saasExtension.Shutdown(i.doneCtx.Err()) - return i.doneCtx.Err() + i.errgroup.AddShutdownGoroutineToErrgroup(ctx, "saas_extension_cleanup", func() error { + i.saasExtension.Shutdown(nil) + return nil }) return nil } -// addGoroutineToErrgroup adds the given goroutine to the errgroup, ensuring that we log its start and exit. -func (i *OsqueryInstance) addGoroutineToErrgroup(ctx context.Context, goroutineName string, goroutine func() error) { - i.errgroup.Go(func() error { - defer i.slogger.Log(ctx, slog.LevelInfo, - "exiting goroutine in errgroup", - "goroutine_name", goroutineName, - ) - - i.slogger.Log(ctx, slog.LevelInfo, - "starting goroutine in errgroup", - "goroutine_name", goroutineName, - ) - - return goroutine() - }) -} - -// addShutdownGoroutineToErrgroup adds the given goroutine to the errgroup, ensuring that we log its start and exit. -// The goroutine will not execute until the instance has received a signal to exit. -func (i *OsqueryInstance) addShutdownGoroutineToErrgroup(ctx context.Context, goroutineName string, goroutine func() error) { - i.errgroup.Go(func() error { - defer i.slogger.Log(ctx, slog.LevelInfo, - "exiting shutdown goroutine in errgroup", - "goroutine_name", goroutineName, - ) - - // Wait for errgroup to exit - <-i.doneCtx.Done() - - i.slogger.Log(ctx, slog.LevelInfo, - "starting shutdown goroutine in errgroup", - "goroutine_name", goroutineName, - ) - - return goroutine() - }) -} - // osqueryFilePaths is a struct which contains the relevant file paths needed to // launch an osqueryd instance. type osqueryFilePaths struct { @@ -857,7 +782,7 @@ func (i *OsqueryInstance) StartOsqueryExtensionManagerServer(name string, socket i.extensionManagerServers = append(i.extensionManagerServers, extensionManagerServer) // Start! - i.addGoroutineToErrgroup(context.TODO(), name, func() error { + i.errgroup.AddGoroutineToErrgroup(context.TODO(), name, func() error { if err := extensionManagerServer.Start(); err != nil { i.slogger.Log(context.TODO(), slog.LevelInfo, "extension manager server startup got error", @@ -871,15 +796,16 @@ func (i *OsqueryInstance) StartOsqueryExtensionManagerServer(name string, socket }) // register a shutdown routine - i.addShutdownGoroutineToErrgroup(context.TODO(), fmt.Sprintf("%s_cleanup", name), func() error { + i.errgroup.AddShutdownGoroutineToErrgroup(context.TODO(), fmt.Sprintf("%s_cleanup", name), func() error { if err := extensionManagerServer.Shutdown(context.TODO()); err != nil { + // Log error, but no need to bubble it up further i.slogger.Log(context.TODO(), slog.LevelInfo, "got error while shutting down extension server", "err", err, "extension_name", name, ) } - return i.doneCtx.Err() + return nil }) return nil