diff --git a/internal/api/api.go b/internal/api/api.go index a09d045cb..8ec895dcc 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -9,6 +9,7 @@ import ( "github.com/sebest/xff" "github.com/sirupsen/logrus" "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/taskafter" "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/hooks/hookshttp" "github.com/supabase/auth/internal/hooks/hookspgfunc" @@ -138,6 +139,8 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne r.UseBypass(api.databaseCleanup(cleanup)) } + r.UseBypass(taskafter.Middleware) + r.Get("/health", api.HealthCheck) r.Get("/.well-known/jwks.json", api.Jwks) diff --git a/internal/api/mail.go b/internal/api/mail.go index f90d9a74c..598e0e113 100644 --- a/internal/api/mail.go +++ b/internal/api/mail.go @@ -17,6 +17,7 @@ import ( "github.com/sethvargo/go-password/password" "github.com/supabase/auth/internal/api/apierrors" "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/api/taskafter" "github.com/supabase/auth/internal/crypto" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/storage" @@ -650,37 +651,42 @@ func (a *API) sendEmail(r *http.Request, tx *storage.Connection, u *models.User, EmailData: emailData, } output := v0hooks.SendEmailOutput{} - return a.hooksMgr.InvokeHook(tx, r, &input, &output) - } - mr := a.Mailer() - var err error - switch emailActionType { - case mail.SignupVerification: - err = mr.ConfirmationMail(r, u, otp, referrerURL, externalURL) - case mail.MagicLinkVerification: - err = mr.MagicLinkMail(r, u, otp, referrerURL, externalURL) - case mail.ReauthenticationVerification: - err = mr.ReauthenticateMail(r, u, otp) - case mail.RecoveryVerification: - err = mr.RecoveryMail(r, u, otp, referrerURL, externalURL) - case mail.InviteVerification: - err = mr.InviteMail(r, u, otp, referrerURL, externalURL) - case mail.EmailChangeVerification: - err = mr.EmailChangeMail(r, u, otpNew, otp, referrerURL, externalURL) - default: - err = errors.New("invalid email action type") - } - - switch { - case errors.Is(err, mail.ErrInvalidEmailAddress), - errors.Is(err, mail.ErrInvalidEmailFormat), - errors.Is(err, mail.ErrInvalidEmailDNS): - return apierrors.NewBadRequestError( - apierrors.ErrorCodeEmailAddressInvalid, - "Email address %q is invalid", - u.GetEmail()) - default: - return err + return taskafter.Queue(ctx, func() error { + return a.hooksMgr.InvokeHook(tx, r, &input, &output) + }) } + + return taskafter.Queue(ctx, func() error { + mr := a.Mailer() + var err error + switch emailActionType { + case mail.SignupVerification: + err = mr.ConfirmationMail(r, u, otp, referrerURL, externalURL) + case mail.MagicLinkVerification: + err = mr.MagicLinkMail(r, u, otp, referrerURL, externalURL) + case mail.ReauthenticationVerification: + err = mr.ReauthenticateMail(r, u, otp) + case mail.RecoveryVerification: + err = mr.RecoveryMail(r, u, otp, referrerURL, externalURL) + case mail.InviteVerification: + err = mr.InviteMail(r, u, otp, referrerURL, externalURL) + case mail.EmailChangeVerification: + err = mr.EmailChangeMail(r, u, otpNew, otp, referrerURL, externalURL) + default: + err = errors.New("invalid email action type") + } + + switch { + case errors.Is(err, mail.ErrInvalidEmailAddress), + errors.Is(err, mail.ErrInvalidEmailFormat), + errors.Is(err, mail.ErrInvalidEmailDNS): + return apierrors.NewBadRequestError( + apierrors.ErrorCodeEmailAddressInvalid, + "Email address %q is invalid", + u.GetEmail()) + default: + return err + } + }) } diff --git a/internal/api/signup.go b/internal/api/signup.go index 09ac43524..c8b254f14 100644 --- a/internal/api/signup.go +++ b/internal/api/signup.go @@ -11,6 +11,7 @@ import ( "github.com/supabase/auth/internal/api/apierrors" "github.com/supabase/auth/internal/api/provider" "github.com/supabase/auth/internal/api/sms_provider" + "github.com/supabase/auth/internal/api/taskafter" "github.com/supabase/auth/internal/metering" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/storage" @@ -294,7 +295,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { if err != nil { return err } - return sendJSON(w, http.StatusOK, sanitizedUser) + return taskafter.SendJSON(ctx, w, http.StatusOK, sanitizedUser) } return err } diff --git a/internal/api/taskafter/taskafter.go b/internal/api/taskafter/taskafter.go new file mode 100644 index 000000000..14a5ea1a3 --- /dev/null +++ b/internal/api/taskafter/taskafter.go @@ -0,0 +1,195 @@ +// Package taskafter contains utilities for contextually queueing and firing +// tasks in FIFO order. +package taskafter + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "sync" + + pkgerrors "github.com/pkg/errors" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/observability" +) + +func Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r = r.WithContext(With(r.Context())) + defer func() { + if err := Fire(r.Context()); err != nil { + log := observability.GetLogEntry(r).Entry + log.WithError(err).Warn("error running 1 or more tasks") + } + }() + next.ServeHTTP(w, r) + }) +} + +type task struct { + name string + fn func() error +} + +type state struct { + mu sync.Mutex + done bool + queue []*task + seen map[string]bool + res *response +} + +type response struct { + w http.ResponseWriter + status int + obj any +} + +func newState() *state { + return &state{ + seen: make(map[string]bool), + } +} + +func (o *state) respond() error { + if o.res == nil { + return nil + } + + o.res.w.Header().Set("Content-Type", "application/json") + b, err := json.Marshal(o.res.obj) + if err != nil { + msg := fmt.Sprintf("Error encoding json response: %v", o.res.obj) + return pkgerrors.Wrap(err, msg) + } + o.res.w.WriteHeader(o.res.status) + _, err = o.res.w.Write(b) + return err +} + +func (o *state) fire() error { + o.mu.Lock() + defer o.mu.Unlock() + if o.done { + err := fmt.Errorf("%w: duplicate call to Fire", errPkg) + return apierrors.NewInternalServerError( + "error tasking hooks").WithInternalError(err) + } + o.done = true + + var errs []error + for _, tr := range o.queue { + err := tr.fn() + if err != nil { + errs = append(errs, fmt.Errorf("%v: %w", tr.name, err)) + } + } + if err := o.respond(); err != nil { + errs = append(errs, err) + } + return errors.Join(errs...) +} + +func (o *state) add(name string, fn func() error) error { + o.mu.Lock() + defer o.mu.Unlock() + if o.done { + err := fmt.Errorf("%w: unable to add tasks after a call to Fire", errPkg) + return apierrors.NewInternalServerError( + "failed to add task").WithInternalError(err) + } + if name != "" { + if o.seen[name] { + return nil + } + o.seen[name] = true + } + + tr := &task{ + fn: fn, + name: name, + } + o.queue = append(o.queue, tr) + return nil +} + +var ( + ctxKey = new(int) + errPkg = errors.New("taskafter") + errCtxInternal = fmt.Errorf( + "%w: context is missing *taskafter.state", errPkg) + errCtx = apierrors.NewInternalServerError( + "unable to queue or run tasks"). + WithInternalError(errCtxInternal) +) + +// Fire will call each queued task previously queued with Defer and return a nil +// error. If err is non-nil it will be 1 or more errors that occurred during +// firing joined by errors.Join(). +func Fire(ctx context.Context) error { + st := from(ctx) + if st == nil { + return errCtx + } + return st.fire() +} + +// Once will queue the first task given by name to run at the end of the request +// in FIFO order or return an error if Fire has already been called. +func Once(ctx context.Context, name string, taskFn func() error) error { + return add(ctx, name, taskFn) +} + +// Queue will queue a task to run at the end of the request in FIFO order or +// return an error if Fire has already been called. +func Queue(ctx context.Context, taskFn func() error) error { + return add(ctx, "", taskFn) +} + +// SendJSON sets the response to be sent at the end of Fire(). +func SendJSON( + ctx context.Context, + w http.ResponseWriter, + status int, + obj interface{}, +) error { + st := from(ctx) + if st == nil { + return errCtx + } + st.mu.Lock() + defer st.mu.Unlock() + + st.res = &response{ + w: w, + status: status, + obj: obj, + } + return nil +} + +func add(ctx context.Context, name string, taskFn func() error) error { + st := from(ctx) + if st == nil { + return errCtx + } + return st.add(name, taskFn) +} + +// With sets up the given context for adding tasks. +func With(ctx context.Context) context.Context { + st := from(ctx) + if st == nil { + st = newState() + } + return context.WithValue(ctx, ctxKey, st) +} + +func from(ctx context.Context) *state { + if st, ok := ctx.Value(ctxKey).(*state); ok && st != nil { + return st + } + return nil +} diff --git a/internal/api/taskafter/taskafter_test.go b/internal/api/taskafter/taskafter_test.go new file mode 100644 index 000000000..44cc19b69 --- /dev/null +++ b/internal/api/taskafter/taskafter_test.go @@ -0,0 +1,116 @@ +package taskafter + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestContext(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + err := Once(ctx, `any`, func() error { return nil }) + require.Equal(t, errCtx, err) + require.Equal(t, errCtx, Fire(ctx)) + + ctx = With(ctx) + st := from(ctx) + require.NotNil(t, st) + require.Equal(t, st, from(ctx)) + + err = Once(ctx, `any`, func() error { return nil }) + require.NoError(t, err) + + err = Queue(ctx, func() error { return nil }) + require.NoError(t, err) + + err = Fire(ctx) + require.NoError(t, err) +} + +func TestState(t *testing.T) { + var calls []string + triggerFn := func(name string) (string, func() error) { + return name, func() error { + calls = append(calls, name) + return nil + } + } + + taskNames := []string{ + `after-user-created`, + `after-identity-created`, + `after-identity-linking`, + } + + st := newState() + for _, taskName := range taskNames { + err := st.add(triggerFn(taskName)) + require.NoError(t, err) + } + + require.Equal(t, 0, len(calls)) + + err := st.fire() + require.NoError(t, err) + require.Equal(t, len(taskNames), len(calls)) + + for i, taskName := range taskNames { + require.Equal(t, taskName, calls[i]) + } + + // double fire fails + if err := st.fire(); err == nil { + t.Fatal("exp non-nil err") + } +} + +func TestStateErrors(t *testing.T) { + var calls []string + sentinel := errors.New("sentinel") + triggerFn := func(name string) (string, func() error) { + return name, func() error { + calls = append(calls, name) + return sentinel + } + } + + taskNames := []string{ + `after-user-created`, + `after-identity-created`, + `after-identity-linking`, + } + + st := newState() + for _, taskName := range taskNames { + if err := st.add(triggerFn(taskName)); err != nil { + t.Fatalf("exp nil error; got %v", err) + } + + // double trigger should just be ignored, less burden on callers + require.NoError(t, st.add(triggerFn(taskName))) + } + require.Equal(t, 0, len(calls)) + + fireErr := st.fire() + require.Error(t, fireErr) + require.Equal(t, len(taskNames), len(calls)) + + var b strings.Builder + for i, taskName := range taskNames { + require.Equal(t, taskName, calls[i]) + b.WriteString(string(taskName) + ": sentinel\n") + } + + expErrStr := strings.TrimRight(b.String(), "\n") + require.Equal(t, expErrStr, fireErr.Error()) + + // double fire fails + require.Error(t, st.fire()) + require.Error(t, st.add(triggerFn(`any`))) +}