diff --git a/.env.example b/.env.example index 26c8b26..7f0f2a8 100644 --- a/.env.example +++ b/.env.example @@ -1,6 +1,10 @@ TW_ENV=development TW_LOG_LEVEL=INFO +TW_OTP_PROVIDER=mock +TW_ALLOWED_EMAIL_DOMAINS=@schools.gov.sg +TW_MOCK_ALLOWED_EMAILS=tracy_lim@schools.gov.sg + # TW_VITE_DEV_SERVER_URL=http://localhost:5173 # TW_BUNDLE_DIRECTORY=dist diff --git a/server/cmd/tw/main.go b/server/cmd/tw/main.go index ee1709a..8e78350 100644 --- a/server/cmd/tw/main.go +++ b/server/cmd/tw/main.go @@ -15,6 +15,7 @@ import ( "github.com/String-sg/teacher-workspace/server/internal/config" "github.com/String-sg/teacher-workspace/server/internal/handler" "github.com/String-sg/teacher-workspace/server/internal/middleware" + "github.com/String-sg/teacher-workspace/server/internal/otp" "github.com/String-sg/teacher-workspace/server/pkg/dotenv" "golang.org/x/sync/errgroup" ) @@ -58,7 +59,23 @@ func main() { } func run(ctx context.Context, cfg *config.Config) error { - h, err := handler.New(cfg) + var otpProvider otp.Provider + switch cfg.OTP.Provider { + case config.OTPProviderMock: + otpProvider = otp.NewMockProvider(cfg.OTP.Mock.AllowedEmails) + case config.OTPProviderOTPaaS: + otpProvider = otp.NewOTPaaSProvider( + cfg.OTP.OTPaaS.Host, + cfg.OTP.OTPaaS.ID, + cfg.OTP.OTPaaS.Namespace, + cfg.OTP.OTPaaS.Secret, + cfg.OTP.OTPaaS.Timeout, + ) + default: + return fmt.Errorf("unsupported OTP provider: %q", cfg.OTP.Provider) + } + + h, err := handler.New(cfg, otpProvider) if err != nil { return fmt.Errorf("create handler: %w", err) } diff --git a/server/internal/config/config.go b/server/internal/config/config.go index 0e90a80..3a649a0 100644 --- a/server/internal/config/config.go +++ b/server/internal/config/config.go @@ -10,10 +10,14 @@ import ( ) type Environment string +type Provider string const ( EnvironmentDevelopment Environment = "development" EnvironmentProduction Environment = "production" + + OTPProviderOTPaaS Provider = "otpaas" + OTPProviderMock Provider = "mock" ) // Config is the main configuration for the application. @@ -21,11 +25,13 @@ type Config struct { Environment Environment `dotenv:"TW_ENV"` LogLevel slog.Level `dotenv:"TW_LOG_LEVEL"` + AllowedEmailDomains []string `dotenv:"TW_ALLOWED_EMAIL_DOMAINS"` + ViteDevServerURL *url.URL `dotenv:"TW_VITE_DEV_SERVER_URL"` BundleDirectory string `dotenv:"TW_BUNDLE_DIRECTORY"` Server ServerConfig `dotenv:",squash"` - OTPaaS OTPaaSConfig `dotenv:",squash"` + OTP OTPConfig `dotenv:",squash"` } // ServerConfig represents the configuration for the HTTP server. @@ -38,6 +44,12 @@ type ServerConfig struct { IdleTimeout time.Duration `dotenv:"TW_SERVER_IDLE_TIMEOUT"` } +type OTPConfig struct { + Provider Provider `dotenv:"TW_OTP_PROVIDER"` + OTPaaS OTPaaSConfig `dotenv:",squash"` + Mock MockConfig `dotenv:",squash"` +} + type OTPaaSConfig struct { Host string `dotenv:"TW_OTPAAS_HOST"` ID string `dotenv:"TW_OTPAAS_ID"` @@ -47,12 +59,32 @@ type OTPaaSConfig struct { Timeout time.Duration `dotenv:"TW_OTPAAS_TIMEOUT"` } +type MockConfig struct { + AllowedEmails []string `dotenv:"TW_MOCK_ALLOWED_EMAILS"` +} + // Default returns the default configuration for the application. func Default() *Config { return &Config{ Environment: EnvironmentDevelopment, LogLevel: slog.LevelInfo, + OTP: OTPConfig{ + Provider: OTPProviderOTPaaS, + OTPaaS: OTPaaSConfig{ + Host: "https://otp.techpass.suite.gov.sg", + ID: "", + Namespace: "", + Secret: "", + Timeout: 10 * time.Second, + }, + Mock: MockConfig{ + AllowedEmails: nil, + }, + }, + + AllowedEmailDomains: []string{"@schools.gov.sg"}, + ViteDevServerURL: must(url.Parse("http://localhost:5173")), BundleDirectory: "dist", @@ -64,14 +96,6 @@ func Default() *Config { WriteTimeout: 30 * time.Second, IdleTimeout: 60 * time.Second, }, - - OTPaaS: OTPaaSConfig{ - Host: "https://otp.techpass.suite.gov.sg", - ID: "", - Namespace: "", - Secret: "", - Timeout: 10 * time.Second, - }, } } @@ -82,6 +106,12 @@ func (cfg *Config) Validate() error { errs = append(errs, fmt.Errorf("TW_ENV must be one of %q or %q; got %q", EnvironmentDevelopment, EnvironmentProduction, cfg.Environment)) } + if len(cfg.AllowedEmailDomains) == 0 { + errs = append(errs, errors.New("TW_ALLOWED_EMAIL_DOMAINS is required")) + } + + errs = append(errs, cfg.Server.validate()) + switch cfg.Environment { case EnvironmentDevelopment: if cfg.ViteDevServerURL.Scheme != "http" && cfg.ViteDevServerURL.Scheme != "https" { @@ -100,7 +130,16 @@ func (cfg *Config) Validate() error { } } - return errors.Join(append(errs, cfg.Server.validate(), cfg.OTPaaS.validate())...) + switch cfg.OTP.Provider { + case OTPProviderOTPaaS: + errs = append(errs, cfg.OTP.OTPaaS.validate()) + case OTPProviderMock: + errs = append(errs, cfg.OTP.Mock.validate()) + default: + errs = append(errs, fmt.Errorf("TW_OTP_PROVIDER must be one of %q or %q; got %q", OTPProviderOTPaaS, OTPProviderMock, cfg.OTP.Provider)) + } + + return errors.Join(errs...) } func (c ServerConfig) validate() error { @@ -147,6 +186,16 @@ func (c OTPaaSConfig) validate() error { return errors.Join(errs...) } +func (c MockConfig) validate() error { + var errs []error + + if len(c.AllowedEmails) == 0 { + errs = append(errs, errors.New("TW_MOCK_ALLOWED_EMAILS is required")) + } + + return errors.Join(errs...) +} + func must[T any](value T, err error) T { if err != nil { panic(err) diff --git a/server/internal/handler/handler.go b/server/internal/handler/handler.go index fa9ae01..fe99e86 100644 --- a/server/internal/handler/handler.go +++ b/server/internal/handler/handler.go @@ -9,6 +9,7 @@ import ( "github.com/String-sg/teacher-workspace/server/internal/config" "github.com/String-sg/teacher-workspace/server/internal/htmlutil" "github.com/String-sg/teacher-workspace/server/internal/middleware" + "github.com/String-sg/teacher-workspace/server/internal/otp" ) const ( @@ -30,16 +31,16 @@ type Handler struct { cfg *config.Config executor htmlutil.TemplateExecutor - client *http.Client - proxy *httputil.ReverseProxy assets http.Handler + + otpProvider otp.Provider } -func New(cfg *config.Config) (*Handler, error) { +func New(cfg *config.Config, otpProvider otp.Provider) (*Handler, error) { h := &Handler{ - cfg: cfg, - client: &http.Client{Timeout: cfg.OTPaaS.Timeout}, + cfg: cfg, + otpProvider: otpProvider, } switch cfg.Environment { diff --git a/server/internal/handler/otp.go b/server/internal/handler/otp.go index b1d50b9..afda8f5 100644 --- a/server/internal/handler/otp.go +++ b/server/internal/handler/otp.go @@ -1,22 +1,18 @@ package handler import ( - "bytes" - "context" - "crypto/hmac" "crypto/rand" - "crypto/sha256" "encoding/base64" - "encoding/hex" "encoding/json" "errors" - "io" "log/slog" + "mime" "net/http" "strings" "github.com/String-sg/teacher-workspace/server/internal/config" "github.com/String-sg/teacher-workspace/server/internal/middleware" + "github.com/String-sg/teacher-workspace/server/internal/otp" ) var store = make(map[string]map[string]string) @@ -33,24 +29,15 @@ type errorBody struct { } const ( - ErrorCodeInvalidForm = "INVALID_FORM" - ErrorCodeInvalidAuth = "AUTHORIZATION_FAILED" - ErrorCodeInternalServerError = "INTERNAL_SERVER_ERROR" - ErrorCodeRequestTimeout = "REQUEST_TIMEOUT" + ErrorCodeInvalidForm = "INVALID_FORM" + ErrorCodeInvalidAuth = "AUTHORIZATION_FAILED" + ErrorCodeRateLimited = "RATE_LIMITED" ) -// isAllowedEmail returns true if the email domain is allowed for the given environment. -// Production: only @schools.gov.sg. Staging/development: @schools.gov.sg or @tech.gov.sg. -func isAllowedEmail(email string, env config.Environment) bool { - if env == config.EnvironmentProduction { - return strings.HasSuffix(email, "@schools.gov.sg") - } - return strings.HasSuffix(email, "@schools.gov.sg") || strings.HasSuffix(email, "@tech.gov.sg") -} - // writeClientErrorResponse writes a JSON error response for 4xx client errors func writeClientErrorResponse(w http.ResponseWriter, logger *slog.Logger, statusCode int, code string, message string, errors ...errorBody) { - w.Header().Set("Content-Type", "application/json") + w.Header().Set(HeaderContentType, MIMEApplicationJSONCharsetUTF8) + w.Header().Set(HeaderXContentTypeOptions, "nosniff") w.WriteHeader(statusCode) if err := json.NewEncoder(w).Encode(errorResponse{ Code: code, @@ -63,19 +50,10 @@ func writeClientErrorResponse(w http.ResponseWriter, logger *slog.Logger, status // writeServerErrorResponse writes a plain text error response for 5xx server errors using http.Error func writeServerErrorResponse(w http.ResponseWriter, statusCode int, message string) { + w.Header().Set(HeaderXContentTypeOptions, "nosniff") http.Error(w, message, statusCode) } -func buildAuthToken(appId, appNamespace, appSecret string) string { - mac := hmac.New(sha256.New, []byte(appSecret)) - mac.Write([]byte(appId)) - - sig := hex.EncodeToString(mac.Sum(nil)) - payload := appNamespace + ":" + appId + ":" + sig - - return base64.StdEncoding.EncodeToString([]byte(payload)) -} - type requestOTPRequest struct { Email string `json:"email"` } @@ -84,95 +62,70 @@ type requestOTPResponse struct { ID string `json:"id"` } -type requestOTPOTPaasRequest struct { - Email string `json:"email"` -} - -type requestOTPPaasResponse struct { - ID string `json:"id"` -} - func (h *Handler) RequestOTP(w http.ResponseWriter, r *http.Request) { logger := middleware.LoggerFromContext(r.Context()) + mediaType, _, err := mime.ParseMediaType(r.Header.Get(HeaderContentType)) + + if err != nil || mediaType != MIMEApplicationJSON { + writeClientErrorResponse(w, logger, http.StatusUnsupportedMediaType, ErrorCodeInvalidForm, "One or more input has an error") + logger.Error("Content-Type must be application/json", "err", err) + return + } + var input requestOTPRequest if err := json.NewDecoder(r.Body).Decode(&input); err != nil { writeClientErrorResponse(w, logger, http.StatusBadRequest, ErrorCodeInvalidForm, "One or more input has an error") - logger.Error("Email not found in request body", "err", err) + logger.Error("Problem parsing JSON", "err", err) return } - if !isAllowedEmail(input.Email, h.cfg.Environment) { + email := strings.TrimSpace(strings.ToLower(input.Email)) + + if email == "" { writeClientErrorResponse(w, logger, http.StatusUnprocessableEntity, ErrorCodeInvalidForm, "One or more input has an error") - logger.Error("Email is not a valid schools.gov.sg email") + logger.Error("Email required") return } - otpaasPayload := requestOTPOTPaasRequest(input) - payload, err := json.Marshal(otpaasPayload) - if err != nil { - writeServerErrorResponse(w, http.StatusInternalServerError, "Internal server error") - logger.Error("Failed to marshal request body", "err", err) - return + hasValidDomain := false + for _, domain := range h.cfg.AllowedEmailDomains { + if strings.HasSuffix(email, domain) { + hasValidDomain = true + break + } } - req, err := http.NewRequest("POST", h.cfg.OTPaaS.Host+"/otp", bytes.NewReader(payload)) - if err != nil { - writeServerErrorResponse(w, http.StatusInternalServerError, "Internal server error") - logger.Error("Failed to create request", "err", err) + if !hasValidDomain { + writeClientErrorResponse(w, logger, http.StatusUnprocessableEntity, ErrorCodeInvalidForm, "One or more input has an error") + logger.Error("Email domain not allowed") return } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+buildAuthToken(h.cfg.OTPaaS.Secret, h.cfg.OTPaaS.ID, h.cfg.OTPaaS.Namespace)) - req.Header.Set("X-App-Id", h.cfg.OTPaaS.ID) - req.Header.Set("X-App-Namespace", h.cfg.OTPaaS.Namespace) + flowID, err := h.otpProvider.RequestOTP(r.Context(), email) - resp, err := h.client.Do(req) if err != nil { - if errors.Is(err, context.DeadlineExceeded) { - writeServerErrorResponse(w, http.StatusGatewayTimeout, "Request timeout. Please try again later.") - logger.Error("OTPaas request timeout", "err", err) + switch { + case errors.Is(err, otp.ErrRateLimited): + writeClientErrorResponse(w, logger, http.StatusTooManyRequests, ErrorCodeRateLimited, "Too many requests") + logger.Error("OTP request rate limited", "email", email) + return + case errors.Is(err, otp.ErrDomainNotAllowed): + writeClientErrorResponse(w, logger, http.StatusUnprocessableEntity, ErrorCodeInvalidForm, "One or more input has an error") + logger.Error("OTP request rejected: email domain not allowed", "email", email) + return + case errors.Is(err, otp.ErrEmailNotAllowed): + writeClientErrorResponse(w, logger, http.StatusUnprocessableEntity, ErrorCodeInvalidForm, "One or more input has an error") + logger.Error("OTP request rejected: email not allowed", "email", email) + return + default: + writeServerErrorResponse(w, http.StatusInternalServerError, "Internal server error") + logger.Error("failed to request OTP", "err", err) return } - - writeServerErrorResponse(w, http.StatusInternalServerError, "Internal server error") - logger.Error("Error sending request to OTPaas", "err", err) - return - } - - if resp.StatusCode != http.StatusOK { - // TODO: update the error message from figma when available - writeServerErrorResponse(w, http.StatusInternalServerError, "Something went wrong. Please try again later.") - logger.Error("Authorization failed", "err", err, "otpaas_status_code", resp.StatusCode) - return } - defer func() { - if err := resp.Body.Close(); err != nil { - logger.Error("Failed to close response body", "err", err) - } - }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - writeServerErrorResponse(w, http.StatusInternalServerError, "Internal server error") - logger.Error("Failed to read response body", "err", err, "otpaas_status_code", resp.StatusCode) - return - } - - var otpResp requestOTPPaasResponse - if err := json.Unmarshal(body, &otpResp); err != nil { - writeServerErrorResponse(w, http.StatusInternalServerError, "Internal server error") - logger.Error("Failed to unmarshal response body", "err", err, "otpaas_status_code", resp.StatusCode) - return - } - - if otpResp.ID == "" { - writeServerErrorResponse(w, http.StatusInternalServerError, "Internal server error") - logger.Error("Failed to get `otp_flow_id` from OTPaas", "err", err, "otpaas_status_code", resp.StatusCode) - return - } + logger.Info("OTP requested", "email", email, "flow_id", flowID) var sessionID string @@ -181,7 +134,7 @@ func (h *Handler) RequestOTP(w http.ResponseWriter, r *http.Request) { id := make([]byte, 32) if _, err := rand.Read(id); err != nil { writeServerErrorResponse(w, http.StatusInternalServerError, "Internal server error") - logger.Error("Failed to generate session ID", "err", err, "otpaas_status_code", resp.StatusCode) + logger.Error("Failed to generate session ID", "err", err) return } sessionID = base64.RawURLEncoding.EncodeToString(id) @@ -189,7 +142,7 @@ func (h *Handler) RequestOTP(w http.ResponseWriter, r *http.Request) { sessionID = c.Value } - store[sessionID] = map[string]string{"otp_flow_id": otpResp.ID} + store[sessionID] = map[string]string{"otp_flow_id": flowID} cookie := http.Cookie{ Name: "session_id", @@ -201,10 +154,11 @@ func (h *Handler) RequestOTP(w http.ResponseWriter, r *http.Request) { } http.SetCookie(w, &cookie) - w.Header().Set("Content-Type", "application/json") + w.Header().Set(HeaderContentType, MIMEApplicationJSONCharsetUTF8) + w.Header().Set(HeaderXContentTypeOptions, "nosniff") w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(requestOTPResponse(otpResp)); err != nil { + if err := json.NewEncoder(w).Encode(requestOTPResponse{ID: flowID}); err != nil { logger.Error("Failed to encode error response", "err", err) } @@ -214,23 +168,39 @@ type verifyOTPRequest struct { PIN string `json:"pin"` } -type verifyOTPOTPaasRequest struct { - PIN string `json:"pin"` +type verifyOTPResponse struct { + ID string `json:"id"` + Email string `json:"email"` } func (h *Handler) VerifyOTP(w http.ResponseWriter, r *http.Request) { logger := middleware.LoggerFromContext(r.Context()) + mediaType, _, err := mime.ParseMediaType(r.Header.Get(HeaderContentType)) + if err != nil || mediaType != MIMEApplicationJSON { + writeClientErrorResponse(w, logger, http.StatusUnsupportedMediaType, ErrorCodeInvalidForm, "One or more input has an error") + logger.Error("Content-Type must be application/json", "err", err) + return + } + var input verifyOTPRequest if err := json.NewDecoder(r.Body).Decode(&input); err != nil { writeClientErrorResponse(w, logger, http.StatusBadRequest, ErrorCodeInvalidForm, "One or more input has an error") - logger.Error("Pin not found in request body", "err", err) + logger.Error("Problem parsing JSON", "err", err) + return + } + + pin := strings.TrimSpace(input.PIN) + + if pin == "" { + writeClientErrorResponse(w, logger, http.StatusUnprocessableEntity, ErrorCodeInvalidForm, "One or more input has an error") + logger.Error("PIN required") return } - if len(input.PIN) != 6 { + if !isValidPIN(pin) { writeClientErrorResponse(w, logger, http.StatusUnprocessableEntity, ErrorCodeInvalidForm, "One or more input has an error") - logger.Error("Pin is not a valid 6 digit PIN") + logger.Error("PIN must be 6 digits") return } @@ -243,67 +213,52 @@ func (h *Handler) VerifyOTP(w http.ResponseWriter, r *http.Request) { session, ok := store[c.Value] if !ok { - // TODO: update the error message from figma when available writeClientErrorResponse(w, logger, http.StatusUnauthorized, ErrorCodeInvalidAuth, "Failed to authenticate session.") - logger.Error("Session not found in store", "err", err) - return - } - - otpaasPayload := verifyOTPOTPaasRequest(input) - payload, err := json.Marshal(otpaasPayload) - if err != nil { - writeServerErrorResponse(w, http.StatusInternalServerError, "Internal server error") - logger.Error("Failed to marshal request body", "err", err) - return - } - - otpFlowID := session["otp_flow_id"] - - req, err := http.NewRequest("PUT", h.cfg.OTPaaS.Host+"/otp/"+otpFlowID, bytes.NewReader(payload)) - if err != nil { - writeServerErrorResponse(w, http.StatusInternalServerError, "Internal server error") - logger.Error("Failed to create request", "err", err) + logger.Error("Session not found in store") return } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+buildAuthToken(h.cfg.OTPaaS.Secret, h.cfg.OTPaaS.ID, h.cfg.OTPaaS.Namespace)) - req.Header.Set("X-App-Id", h.cfg.OTPaaS.ID) - req.Header.Set("X-App-Namespace", h.cfg.OTPaaS.Namespace) + flowID := session["otp_flow_id"] - resp, err := h.client.Do(req) + email, err := h.otpProvider.VerifyOTP(r.Context(), flowID, pin) if err != nil { - if errors.Is(err, context.DeadlineExceeded) { - writeServerErrorResponse(w, http.StatusGatewayTimeout, "Request timeout. Please try again later.") - logger.Error("OTPaas request timeout", "err", err) - return - } - - writeServerErrorResponse(w, http.StatusInternalServerError, "Internal server error") - logger.Error("Error sending request to OTPaas", "err", err) - return - } - - if resp.StatusCode != http.StatusOK { - switch resp.StatusCode { - case http.StatusUnauthorized: + switch { + case errors.Is(err, otp.ErrInvalidPIN): writeClientErrorResponse(w, logger, http.StatusUnprocessableEntity, ErrorCodeInvalidAuth, "Failed to authenticate session.") - logger.Error("Invalid PIN", "err", err, "otpaas_status_code", resp.StatusCode) - case http.StatusNotFound: + logger.Warn("OTP verify failed: invalid pin", "flow_id", flowID) + return + case errors.Is(err, otp.ErrFlowExpired): writeClientErrorResponse(w, logger, http.StatusUnprocessableEntity, ErrorCodeInvalidAuth, "Failed to authenticate session.") - logger.Error("Pin expired", "err", err, "otpaas_status_code", resp.StatusCode) + logger.Warn("OTP verify failed: flow expired", "flow_id", flowID) + return default: writeServerErrorResponse(w, http.StatusInternalServerError, "Internal server error") - logger.Error("Internal server error", "err", err, "otpaas_status_code", resp.StatusCode) + logger.Error("failed to verify OTP", "err", err, "flow_id", flowID) return } } - defer func() { - if err := resp.Body.Close(); err != nil { - logger.Error("Failed to close response body", "err", err) + logger.Info("OTP verified", "flow_id", flowID, "email", email) + + w.Header().Set(HeaderContentType, MIMEApplicationJSONCharsetUTF8) + w.Header().Set(HeaderXContentTypeOptions, "nosniff") + w.WriteHeader(http.StatusOK) + + if err := json.NewEncoder(w).Encode(verifyOTPResponse{ID: flowID, Email: email}); err != nil { + logger.Error("Failed to encode response", "err", err) + } +} + +func isValidPIN(pin string) bool { + if len(pin) != 6 { + return false + } + + for _, char := range pin { + if char < '0' || char > '9' { + return false } - }() + } - w.WriteHeader(http.StatusNoContent) + return true } diff --git a/server/internal/handler/otp_test.go b/server/internal/handler/otp_test.go index f067e4d..efecad0 100644 --- a/server/internal/handler/otp_test.go +++ b/server/internal/handler/otp_test.go @@ -2,524 +2,620 @@ package handler import ( "bytes" + "context" "encoding/json" - "io" + "fmt" "net/http" "net/http/httptest" "testing" - "time" "github.com/String-sg/teacher-workspace/server/internal/config" + "github.com/String-sg/teacher-workspace/server/internal/otp" "github.com/String-sg/teacher-workspace/server/pkg/require" ) -func resetStore() { - store = make(map[string]map[string]string) +type stubOTPProvider struct { + requestOTP func(ctx context.Context, email string) (string, error) + verifyOTP func(ctx context.Context, flowID, pin string) (string, error) } -type RoundTripperFunc func(*http.Request) (*http.Response, error) +func (s stubOTPProvider) RequestOTP(ctx context.Context, email string) (string, error) { + if s.requestOTP == nil { + return "", nil + } -func (f RoundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { - return f(r) + return s.requestOTP(ctx, email) } -func TestRequestOTP_SuccessProduction(t *testing.T) { - rt := RoundTripperFunc(func(req *http.Request) (*http.Response, error) { - return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader([]byte(`{"id": "123"}`)))}, nil - }) - cfg := config.Default() - cfg.Environment = config.EnvironmentProduction - h := &Handler{cfg: cfg, client: &http.Client{Transport: rt}} - resetStore() +func (s stubOTPProvider) VerifyOTP(ctx context.Context, flowID, pin string) (string, error) { + if s.verifyOTP == nil { + return "", nil + } - payload := map[string]string{"email": "test@schools.gov.sg"} - b, _ := json.Marshal(payload) + return s.verifyOTP(ctx, flowID, pin) +} - req := httptest.NewRequest(http.MethodPost, "/otp/request", bytes.NewReader(b)) +func resetStore() { + store = make(map[string]map[string]string) +} - req.AddCookie(&http.Cookie{Name: "session_id", Value: "abc"}) - rec := httptest.NewRecorder() +func TestHandler_RequestOTP(t *testing.T) { + t.Run("successful response", func(t *testing.T) { + resetStore() + gotEmail := "" + h := &Handler{ + cfg: &config.Config{ + AllowedEmailDomains: []string{"@schools.gov.sg"}, + }, + otpProvider: stubOTPProvider{ + requestOTP: func(ctx context.Context, email string) (string, error) { + gotEmail = email + return "flow-123", nil + }, + }, + } - h.RequestOTP(rec, req) + req := httptest.NewRequest(http.MethodPost, "/otp/request", bytes.NewBufferString(`{"email":"Teacher@schools.gov.sg"}`)) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: "session_id", Value: "abc"}) + rec := httptest.NewRecorder() - res := rec.Result() + h.RequestOTP(rec, req) - require.Equal(t, http.StatusOK, res.StatusCode) + res := rec.Result() - var gotOTPResponse requestOTPResponse - require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &gotOTPResponse)) - require.Equal(t, "123", gotOTPResponse.ID) + require.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, MIMEApplicationJSONCharsetUTF8, res.Header.Get("Content-Type")) + require.Equal(t, "nosniff", res.Header.Get("X-Content-Type-Options")) + require.Equal(t, "teacher@schools.gov.sg", gotEmail) - // Should set/refresh the same cookie value. - var got *http.Cookie - for _, c := range res.Cookies() { - if c.Name == "session_id" { - got = c - break - } - } - require.True(t, got != nil) - require.Equal(t, "abc", got.Value) + var resp requestOTPResponse + require.NoError(t, json.NewDecoder(res.Body).Decode(&resp)) + require.Equal(t, "flow-123", resp.ID) - session, ok := store["abc"] - require.True(t, ok) - require.Equal(t, "123", session["otp_flow_id"]) -} + var got *http.Cookie + for _, c := range res.Cookies() { + if c.Name == "session_id" { + got = c + break + } + } + require.True(t, got != nil) + require.Equal(t, "abc", got.Value) -func TestRequestOTP_SuccessDevelopment(t *testing.T) { - rt := RoundTripperFunc(func(req *http.Request) (*http.Response, error) { - return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader([]byte(`{"id": "123"}`)))}, nil + session, ok := store["abc"] + require.True(t, ok) + require.Equal(t, "flow-123", session["otp_flow_id"]) }) - cfg := config.Default() - cfg.Environment = config.EnvironmentDevelopment - h := &Handler{cfg: cfg, client: &http.Client{Transport: rt}} - resetStore() - - payload := map[string]string{"email": "test@tech.gov.sg"} - b, _ := json.Marshal(payload) - - req := httptest.NewRequest(http.MethodPost, "/otp/request", bytes.NewReader(b)) - - req.AddCookie(&http.Cookie{Name: "session_id", Value: "abc"}) - rec := httptest.NewRecorder() + t.Run("creates new session when no cookie", func(t *testing.T) { + resetStore() + h := &Handler{ + cfg: &config.Config{ + AllowedEmailDomains: []string{"@schools.gov.sg"}, + }, + otpProvider: stubOTPProvider{ + requestOTP: func(ctx context.Context, email string) (string, error) { + return "flow-456", nil + }, + }, + } - h.RequestOTP(rec, req) + req := httptest.NewRequest(http.MethodPost, "/otp/request", bytes.NewBufferString(`{"email":"teacher@schools.gov.sg"}`)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() - res := rec.Result() + h.RequestOTP(rec, req) - require.Equal(t, http.StatusOK, res.StatusCode) + res := rec.Result() - var gotOTPResponse requestOTPResponse - require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &gotOTPResponse)) - require.Equal(t, "123", gotOTPResponse.ID) + require.Equal(t, http.StatusOK, res.StatusCode) - // Should set/refresh the same cookie value. - var got *http.Cookie - for _, c := range res.Cookies() { - if c.Name == "session_id" { - got = c - break + var got *http.Cookie + for _, c := range res.Cookies() { + if c.Name == "session_id" { + got = c + break + } } - } - require.True(t, got != nil) - require.Equal(t, "abc", got.Value) + require.True(t, got != nil) + require.True(t, got.Value != "") - session, ok := store["abc"] - require.True(t, ok) - require.Equal(t, "123", session["otp_flow_id"]) -} + session, ok := store[got.Value] + require.True(t, ok) + require.Equal(t, "flow-456", session["otp_flow_id"]) + }) -func TestRequestOTP_MissingEmail(t *testing.T) { - h := &Handler{cfg: config.Default(), client: &http.Client{}} - resetStore() + t.Run("unsupported media type", func(t *testing.T) { + h := &Handler{ + cfg: &config.Config{}, + otpProvider: stubOTPProvider{}, + } - req := httptest.NewRequest(http.MethodPost, "/otp/request", nil) - rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/otp/request", bytes.NewBufferString(`{"email":"teacher@schools.gov.sg"}`)) + req.Header.Set("Content-Type", "text/plain") + rec := httptest.NewRecorder() - h.RequestOTP(rec, req) + h.RequestOTP(rec, req) - res := rec.Result() - require.Equal(t, http.StatusBadRequest, res.StatusCode) + res := rec.Result() - var got errorResponse - require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) - require.Equal(t, "INVALID_FORM", got.Code) - require.Equal(t, "One or more input has an error", got.Message) -} + require.Equal(t, http.StatusUnsupportedMediaType, res.StatusCode) + require.Equal(t, MIMEApplicationJSONCharsetUTF8, res.Header.Get("Content-Type")) + require.Equal(t, "nosniff", res.Header.Get("X-Content-Type-Options")) -func TestRequestOTP_InvalidEmail(t *testing.T) { - h := &Handler{cfg: config.Default(), client: &http.Client{}} - resetStore() + var errResp errorResponse + require.NoError(t, json.NewDecoder(res.Body).Decode(&errResp)) + require.Equal(t, ErrorCodeInvalidForm, errResp.Code) + require.Equal(t, "One or more input has an error", errResp.Message) + }) - payload := map[string]string{"email": "test@example.com"} - b, _ := json.Marshal(payload) + t.Run("invalid JSON", func(t *testing.T) { + h := &Handler{ + cfg: &config.Config{}, + otpProvider: stubOTPProvider{}, + } - req := httptest.NewRequest(http.MethodPost, "/otp/request", bytes.NewReader(b)) - rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/otp/request", bytes.NewBufferString(`{`)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() - h.RequestOTP(rec, req) + h.RequestOTP(rec, req) - res := rec.Result() - require.Equal(t, http.StatusUnprocessableEntity, res.StatusCode) + res := rec.Result() - var got errorResponse - require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) - require.Equal(t, "INVALID_FORM", got.Code) - require.Equal(t, "One or more input has an error", got.Message) -} + require.Equal(t, http.StatusBadRequest, res.StatusCode) + require.Equal(t, MIMEApplicationJSONCharsetUTF8, res.Header.Get("Content-Type")) + require.Equal(t, "nosniff", res.Header.Get("X-Content-Type-Options")) -func TestRequestOTP_InvalidEmailProduction(t *testing.T) { - cfg := config.Default() - cfg.Environment = config.EnvironmentProduction - h := &Handler{cfg: cfg, client: &http.Client{}} - resetStore() + var errResp errorResponse + require.NoError(t, json.NewDecoder(res.Body).Decode(&errResp)) + require.Equal(t, ErrorCodeInvalidForm, errResp.Code) + require.Equal(t, "One or more input has an error", errResp.Message) + }) - payload := map[string]string{"email": "test@tech.gov.sg"} - b, _ := json.Marshal(payload) + t.Run("empty email", func(t *testing.T) { + h := &Handler{ + cfg: &config.Config{}, + otpProvider: stubOTPProvider{}, + } - req := httptest.NewRequest(http.MethodPost, "/otp/request", bytes.NewReader(b)) - rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/otp/request", bytes.NewBufferString(`{"email":" "}`)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() - h.RequestOTP(rec, req) + h.RequestOTP(rec, req) - res := rec.Result() - require.Equal(t, http.StatusUnprocessableEntity, res.StatusCode) + res := rec.Result() - var got errorResponse - require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) - require.Equal(t, "INVALID_FORM", got.Code) - require.Equal(t, "One or more input has an error", got.Message) -} + require.Equal(t, http.StatusUnprocessableEntity, res.StatusCode) + require.Equal(t, MIMEApplicationJSONCharsetUTF8, res.Header.Get("Content-Type")) + require.Equal(t, "nosniff", res.Header.Get("X-Content-Type-Options")) -func TestRequestOTP_Timeout(t *testing.T) { - rt := RoundTripperFunc(func(req *http.Request) (*http.Response, error) { - select { - case <-req.Context().Done(): - return nil, req.Context().Err() - case <-time.After(200 * time.Millisecond): - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader([]byte(`{"id":"123"}`))), - }, nil - } + var errResp errorResponse + require.NoError(t, json.NewDecoder(res.Body).Decode(&errResp)) + require.Equal(t, ErrorCodeInvalidForm, errResp.Code) + require.Equal(t, "One or more input has an error", errResp.Message) }) - h := &Handler{cfg: config.Default(), client: &http.Client{Timeout: 10 * time.Millisecond, Transport: rt}} - resetStore() - - payload := map[string]string{"email": "test@schools.gov.sg"} - b, _ := json.Marshal(payload) + t.Run("domain not allowed", func(t *testing.T) { + h := &Handler{ + cfg: &config.Config{ + AllowedEmailDomains: []string{"@schools.gov.sg"}, + }, + otpProvider: stubOTPProvider{}, + } - req := httptest.NewRequest(http.MethodPost, "/otp/request", bytes.NewReader(b)) + req := httptest.NewRequest(http.MethodPost, "/otp/request", bytes.NewBufferString(`{"email":"teacher@example.com"}`)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() - req.AddCookie(&http.Cookie{Name: "session_id", Value: "abc"}) - rec := httptest.NewRecorder() + h.RequestOTP(rec, req) - h.RequestOTP(rec, req) + res := rec.Result() - res := rec.Result() + require.Equal(t, http.StatusUnprocessableEntity, res.StatusCode) + require.Equal(t, MIMEApplicationJSONCharsetUTF8, res.Header.Get("Content-Type")) + require.Equal(t, "nosniff", res.Header.Get("X-Content-Type-Options")) - require.Equal(t, http.StatusGatewayTimeout, res.StatusCode) - require.Equal(t, "text/plain; charset=utf-8", res.Header.Get("Content-Type")) - require.Equal(t, "Request timeout. Please try again later.\n", rec.Body.String()) -} - -func TestRequestOTP_NotAuthorized(t *testing.T) { - rt := RoundTripperFunc(func(req *http.Request) (*http.Response, error) { - return &http.Response{StatusCode: http.StatusUnauthorized, Body: io.NopCloser(bytes.NewReader([]byte(`{}`)))}, nil + var errResp errorResponse + require.NoError(t, json.NewDecoder(res.Body).Decode(&errResp)) + require.Equal(t, ErrorCodeInvalidForm, errResp.Code) + require.Equal(t, "One or more input has an error", errResp.Message) }) - h := &Handler{cfg: config.Default(), client: &http.Client{Transport: rt}} - resetStore() - - payload := map[string]string{"email": "test@schools.gov.sg"} - b, _ := json.Marshal(payload) + t.Run("provider rate limited", func(t *testing.T) { + resetStore() + h := &Handler{ + cfg: &config.Config{ + AllowedEmailDomains: []string{"@schools.gov.sg"}, + }, + otpProvider: stubOTPProvider{ + requestOTP: func(context.Context, string) (string, error) { + return "", fmt.Errorf("request OTP: %w", otp.ErrRateLimited) + }, + }, + } - req := httptest.NewRequest(http.MethodPost, "/otp/request", bytes.NewReader(b)) - rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/otp/request", bytes.NewBufferString(`{"email":"teacher@schools.gov.sg"}`)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() - h.RequestOTP(rec, req) + h.RequestOTP(rec, req) - res := rec.Result() + res := rec.Result() - require.Equal(t, http.StatusInternalServerError, res.StatusCode) - require.Equal(t, "text/plain; charset=utf-8", res.Header.Get("Content-Type")) - require.Equal(t, "Something went wrong. Please try again later.\n", rec.Body.String()) -} + require.Equal(t, http.StatusTooManyRequests, res.StatusCode) + require.Equal(t, MIMEApplicationJSONCharsetUTF8, res.Header.Get("Content-Type")) + require.Equal(t, "nosniff", res.Header.Get("X-Content-Type-Options")) -func TestRequestOTP_InternalServerError(t *testing.T) { - rt := RoundTripperFunc(func(req *http.Request) (*http.Response, error) { - return &http.Response{StatusCode: http.StatusInternalServerError, Body: io.NopCloser(bytes.NewReader([]byte(`{}`)))}, nil + var errResp errorResponse + require.NoError(t, json.NewDecoder(res.Body).Decode(&errResp)) + require.Equal(t, ErrorCodeRateLimited, errResp.Code) + require.Equal(t, "Too many requests", errResp.Message) }) - h := &Handler{cfg: config.Default(), client: &http.Client{Transport: rt}} - resetStore() + t.Run("provider domain not allowed", func(t *testing.T) { + resetStore() + h := &Handler{ + cfg: &config.Config{ + AllowedEmailDomains: []string{"@schools.gov.sg"}, + }, + otpProvider: stubOTPProvider{ + requestOTP: func(context.Context, string) (string, error) { + return "", fmt.Errorf("request OTP: %w", otp.ErrDomainNotAllowed) + }, + }, + } - payload := map[string]string{"email": "test@schools.gov.sg"} - b, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPost, "/otp/request", bytes.NewBufferString(`{"email":"teacher@schools.gov.sg"}`)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodPost, "/otp/request", bytes.NewReader(b)) - rec := httptest.NewRecorder() + h.RequestOTP(rec, req) - h.RequestOTP(rec, req) + res := rec.Result() - res := rec.Result() - require.Equal(t, http.StatusInternalServerError, res.StatusCode) - require.Equal(t, "text/plain; charset=utf-8", res.Header.Get("Content-Type")) - require.Equal(t, "Something went wrong. Please try again later.\n", rec.Body.String()) -} + require.Equal(t, http.StatusUnprocessableEntity, res.StatusCode) + require.Equal(t, MIMEApplicationJSONCharsetUTF8, res.Header.Get("Content-Type")) + require.Equal(t, "nosniff", res.Header.Get("X-Content-Type-Options")) -func TestRequestOTP_MissingOTPFlowID(t *testing.T) { - rt := RoundTripperFunc(func(req *http.Request) (*http.Response, error) { - return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader([]byte(`{}`)))}, nil + var errResp errorResponse + require.NoError(t, json.NewDecoder(res.Body).Decode(&errResp)) + require.Equal(t, ErrorCodeInvalidForm, errResp.Code) + require.Equal(t, "One or more input has an error", errResp.Message) }) - h := &Handler{cfg: config.Default(), client: &http.Client{Transport: rt}} - resetStore() - - payload := map[string]string{"email": "test@schools.gov.sg"} - b, _ := json.Marshal(payload) + t.Run("provider email not allowed", func(t *testing.T) { + resetStore() + h := &Handler{ + cfg: &config.Config{ + AllowedEmailDomains: []string{"@schools.gov.sg"}, + }, + otpProvider: stubOTPProvider{ + requestOTP: func(context.Context, string) (string, error) { + return "", fmt.Errorf("request OTP: %w", otp.ErrEmailNotAllowed) + }, + }, + } - req := httptest.NewRequest(http.MethodPost, "/otp/request", bytes.NewReader(b)) - rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/otp/request", bytes.NewBufferString(`{"email":"teacher@schools.gov.sg"}`)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() - h.RequestOTP(rec, req) + h.RequestOTP(rec, req) - res := rec.Result() - require.Equal(t, http.StatusInternalServerError, res.StatusCode) - require.Equal(t, "text/plain; charset=utf-8", res.Header.Get("Content-Type")) - require.Equal(t, "Internal server error\n", rec.Body.String()) + res := rec.Result() - cookies := res.Cookies() - require.True(t, len(cookies) == 0) -} + require.Equal(t, http.StatusUnprocessableEntity, res.StatusCode) + require.Equal(t, MIMEApplicationJSONCharsetUTF8, res.Header.Get("Content-Type")) + require.Equal(t, "nosniff", res.Header.Get("X-Content-Type-Options")) -func TestVerifyOTP_Success(t *testing.T) { - rt := RoundTripperFunc(func(req *http.Request) (*http.Response, error) { - return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader([]byte(`{"id": "123"}`)))}, nil + var errResp errorResponse + require.NoError(t, json.NewDecoder(res.Body).Decode(&errResp)) + require.Equal(t, ErrorCodeInvalidForm, errResp.Code) + require.Equal(t, "One or more input has an error", errResp.Message) }) - h := &Handler{cfg: config.Default(), client: &http.Client{Transport: rt}} - resetStore() - store["abc"] = map[string]string{"otp_flow_id": "123"} + t.Run("provider error", func(t *testing.T) { + resetStore() + h := &Handler{ + cfg: &config.Config{ + AllowedEmailDomains: []string{"@schools.gov.sg"}, + }, + otpProvider: stubOTPProvider{ + requestOTP: func(context.Context, string) (string, error) { + return "", fmt.Errorf("request OTP: %w", otp.ErrUnauthorized) + }, + }, + } - payload := map[string]string{"pin": "123456"} - b, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPost, "/otp/request", bytes.NewBufferString(`{"email":"teacher@schools.gov.sg"}`)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodPost, "/otp/verify", bytes.NewReader(b)) - req.AddCookie(&http.Cookie{Name: "session_id", Value: "abc"}) - rec := httptest.NewRecorder() + h.RequestOTP(rec, req) - h.VerifyOTP(rec, req) + res := rec.Result() - res := rec.Result() - require.Equal(t, http.StatusNoContent, res.StatusCode) + require.Equal(t, http.StatusInternalServerError, res.StatusCode) + require.Equal(t, "text/plain; charset=utf-8", res.Header.Get("Content-Type")) + require.Equal(t, "nosniff", res.Header.Get("X-Content-Type-Options")) + require.Equal(t, "Internal server error\n", rec.Body.String()) + }) } -func TestVerifyOTP_MissingPin(t *testing.T) { - h := &Handler{cfg: config.Default()} - resetStore() - - store["abc"] = map[string]string{"otp_flow_id": "123"} +func TestHandler_VerifyOTP(t *testing.T) { + t.Run("successful response", func(t *testing.T) { + resetStore() + store["abc"] = map[string]string{"otp_flow_id": "flow-123"} + + gotFlowID := "" + gotPIN := "" + h := &Handler{ + cfg: &config.Config{}, + otpProvider: stubOTPProvider{ + verifyOTP: func(ctx context.Context, flowID, pin string) (string, error) { + gotFlowID = flowID + gotPIN = pin + return "teacher@schools.gov.sg", nil + }, + }, + } - req := httptest.NewRequest(http.MethodPost, "/otp/verify", nil) - req.AddCookie(&http.Cookie{Name: "session_id", Value: "abc"}) - rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/otp/verify", bytes.NewBufferString(`{"pin":"123456"}`)) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: "session_id", Value: "abc"}) + rec := httptest.NewRecorder() - h.VerifyOTP(rec, req) + h.VerifyOTP(rec, req) - res := rec.Result() - require.Equal(t, http.StatusBadRequest, res.StatusCode) + res := rec.Result() - var got errorResponse - require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) - require.Equal(t, "INVALID_FORM", got.Code) - require.Equal(t, "One or more input has an error", got.Message) -} + require.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, MIMEApplicationJSONCharsetUTF8, res.Header.Get("Content-Type")) + require.Equal(t, "nosniff", res.Header.Get("X-Content-Type-Options")) + require.Equal(t, "flow-123", gotFlowID) + require.Equal(t, "123456", gotPIN) -func TestVerifyOTP_InvalidPin(t *testing.T) { - h := &Handler{cfg: config.Default()} - resetStore() + var resp verifyOTPResponse + require.NoError(t, json.NewDecoder(res.Body).Decode(&resp)) + require.Equal(t, "flow-123", resp.ID) + require.Equal(t, "teacher@schools.gov.sg", resp.Email) + }) - store["abc"] = map[string]string{"otp_flow_id": "123"} + t.Run("unsupported media type", func(t *testing.T) { + h := &Handler{ + cfg: &config.Config{}, + otpProvider: stubOTPProvider{}, + } - payload := map[string]string{"pin": "1234567"} - b, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPost, "/otp/verify", bytes.NewBufferString(`{"pin":"123456"}`)) + req.Header.Set("Content-Type", "text/plain") + req.AddCookie(&http.Cookie{Name: "session_id", Value: "abc"}) + rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodPost, "/otp/verify", bytes.NewReader(b)) - req.AddCookie(&http.Cookie{Name: "session_id", Value: "abc"}) - rec := httptest.NewRecorder() + h.VerifyOTP(rec, req) - h.VerifyOTP(rec, req) + res := rec.Result() - res := rec.Result() - require.Equal(t, http.StatusUnprocessableEntity, res.StatusCode) + require.Equal(t, http.StatusUnsupportedMediaType, res.StatusCode) + require.Equal(t, MIMEApplicationJSONCharsetUTF8, res.Header.Get("Content-Type")) + require.Equal(t, "nosniff", res.Header.Get("X-Content-Type-Options")) - var got errorResponse - require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) - require.Equal(t, "INVALID_FORM", got.Code) - require.Equal(t, "One or more input has an error", got.Message) -} + var errResp errorResponse + require.NoError(t, json.NewDecoder(res.Body).Decode(&errResp)) + require.Equal(t, ErrorCodeInvalidForm, errResp.Code) + require.Equal(t, "One or more input has an error", errResp.Message) + }) -func TestVerifyOTP_MissingCookie(t *testing.T) { - h := &Handler{cfg: config.Default()} - resetStore() + t.Run("invalid JSON", func(t *testing.T) { + h := &Handler{ + cfg: &config.Config{}, + otpProvider: stubOTPProvider{}, + } - payload := map[string]string{"pin": "123456"} - b, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPost, "/otp/verify", bytes.NewBufferString(`{`)) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: "session_id", Value: "abc"}) + rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodPost, "/otp/verify", bytes.NewReader(b)) - rec := httptest.NewRecorder() + h.VerifyOTP(rec, req) - h.VerifyOTP(rec, req) + res := rec.Result() - res := rec.Result() - require.Equal(t, http.StatusInternalServerError, res.StatusCode) - require.Equal(t, "text/plain; charset=utf-8", res.Header.Get("Content-Type")) - require.Equal(t, "Internal server error\n", rec.Body.String()) -} + require.Equal(t, http.StatusBadRequest, res.StatusCode) + require.Equal(t, MIMEApplicationJSONCharsetUTF8, res.Header.Get("Content-Type")) + require.Equal(t, "nosniff", res.Header.Get("X-Content-Type-Options")) -func TestVerifyOTP_MissingSession(t *testing.T) { - rt := RoundTripperFunc(func(req *http.Request) (*http.Response, error) { - return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader([]byte(`{}`)))}, nil + var errResp errorResponse + require.NoError(t, json.NewDecoder(res.Body).Decode(&errResp)) + require.Equal(t, ErrorCodeInvalidForm, errResp.Code) + require.Equal(t, "One or more input has an error", errResp.Message) }) - h := &Handler{cfg: config.Default(), client: &http.Client{Transport: rt}} - resetStore() - - payload := map[string]string{"pin": "123456"} - b, _ := json.Marshal(payload) + t.Run("empty pin", func(t *testing.T) { + h := &Handler{ + cfg: &config.Config{}, + otpProvider: stubOTPProvider{}, + } - req := httptest.NewRequest(http.MethodPost, "/otp/verify", bytes.NewReader(b)) - req.AddCookie(&http.Cookie{Name: "session_id", Value: "abc"}) - rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/otp/verify", bytes.NewBufferString(`{"pin":""}`)) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: "session_id", Value: "abc"}) + rec := httptest.NewRecorder() - h.VerifyOTP(rec, req) + h.VerifyOTP(rec, req) - res := rec.Result() - require.Equal(t, http.StatusUnauthorized, res.StatusCode) + res := rec.Result() - var got errorResponse - require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) - require.Equal(t, "AUTHORIZATION_FAILED", got.Code) - require.Equal(t, "Failed to authenticate session.", got.Message) -} + require.Equal(t, http.StatusUnprocessableEntity, res.StatusCode) + require.Equal(t, MIMEApplicationJSONCharsetUTF8, res.Header.Get("Content-Type")) + require.Equal(t, "nosniff", res.Header.Get("X-Content-Type-Options")) -func TestVerifyOTP_Timeout(t *testing.T) { - rt := RoundTripperFunc(func(req *http.Request) (*http.Response, error) { - select { - case <-req.Context().Done(): - return nil, req.Context().Err() - case <-time.After(200 * time.Millisecond): - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader([]byte(`{"id":"123"}`))), - }, nil - } + var errResp errorResponse + require.NoError(t, json.NewDecoder(res.Body).Decode(&errResp)) + require.Equal(t, ErrorCodeInvalidForm, errResp.Code) + require.Equal(t, "One or more input has an error", errResp.Message) }) - h := &Handler{cfg: config.Default(), client: &http.Client{Timeout: 10 * time.Millisecond, Transport: rt}} - resetStore() - store["abc"] = map[string]string{"otp_flow_id": "123"} + t.Run("invalid pin length", func(t *testing.T) { + h := &Handler{ + cfg: &config.Config{}, + otpProvider: stubOTPProvider{}, + } - payload := map[string]string{"pin": "123456"} - b, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPost, "/otp/verify", bytes.NewBufferString(`{"pin":"12345"}`)) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: "session_id", Value: "abc"}) + rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodPost, "/otp/verify", bytes.NewReader(b)) - req.AddCookie(&http.Cookie{Name: "session_id", Value: "abc"}) - rec := httptest.NewRecorder() + h.VerifyOTP(rec, req) - h.VerifyOTP(rec, req) + res := rec.Result() - res := rec.Result() - require.Equal(t, http.StatusGatewayTimeout, res.StatusCode) - require.Equal(t, "text/plain; charset=utf-8", res.Header.Get("Content-Type")) - require.Equal(t, "Request timeout. Please try again later.\n", rec.Body.String()) -} + require.Equal(t, http.StatusUnprocessableEntity, res.StatusCode) + require.Equal(t, MIMEApplicationJSONCharsetUTF8, res.Header.Get("Content-Type")) + require.Equal(t, "nosniff", res.Header.Get("X-Content-Type-Options")) -func TestVerifyOTP_Unauthorized(t *testing.T) { - rt := RoundTripperFunc(func(req *http.Request) (*http.Response, error) { - return &http.Response{StatusCode: http.StatusUnauthorized, Body: io.NopCloser(bytes.NewReader([]byte(`{}`)))}, nil + var errResp errorResponse + require.NoError(t, json.NewDecoder(res.Body).Decode(&errResp)) + require.Equal(t, ErrorCodeInvalidForm, errResp.Code) + require.Equal(t, "One or more input has an error", errResp.Message) }) - h := &Handler{cfg: config.Default(), client: &http.Client{Transport: rt}} - resetStore() + t.Run("missing cookie", func(t *testing.T) { + h := &Handler{ + cfg: &config.Config{}, + otpProvider: stubOTPProvider{}, + } - store["abc"] = map[string]string{"otp_flow_id": "123"} + req := httptest.NewRequest(http.MethodPost, "/otp/verify", bytes.NewBufferString(`{"pin":"123456"}`)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() - payload := map[string]string{"pin": "123456"} - b, _ := json.Marshal(payload) + h.VerifyOTP(rec, req) - req := httptest.NewRequest(http.MethodPost, "/otp/verify", bytes.NewReader(b)) - req.AddCookie(&http.Cookie{Name: "session_id", Value: "abc"}) - rec := httptest.NewRecorder() + res := rec.Result() - h.VerifyOTP(rec, req) + require.Equal(t, http.StatusInternalServerError, res.StatusCode) + }) - res := rec.Result() - require.Equal(t, http.StatusUnprocessableEntity, res.StatusCode) + t.Run("missing session", func(t *testing.T) { + resetStore() + h := &Handler{ + cfg: &config.Config{}, + otpProvider: stubOTPProvider{}, + } - var got errorResponse - require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) - require.Equal(t, "AUTHORIZATION_FAILED", got.Code) - require.Equal(t, "Failed to authenticate session.", got.Message) -} + req := httptest.NewRequest(http.MethodPost, "/otp/verify", bytes.NewBufferString(`{"pin":"123456"}`)) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: "session_id", Value: "abc"}) + rec := httptest.NewRecorder() -func TestVerifyOTP_NotFound(t *testing.T) { - rt := RoundTripperFunc(func(req *http.Request) (*http.Response, error) { - return &http.Response{StatusCode: http.StatusNotFound, Body: io.NopCloser(bytes.NewReader([]byte(`{}`)))}, nil - }) + h.VerifyOTP(rec, req) + + res := rec.Result() - h := &Handler{cfg: config.Default(), client: &http.Client{Transport: rt}} - resetStore() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) + require.Equal(t, MIMEApplicationJSONCharsetUTF8, res.Header.Get("Content-Type")) + require.Equal(t, "nosniff", res.Header.Get("X-Content-Type-Options")) - store["abc"] = map[string]string{"otp_flow_id": "123"} + var errResp errorResponse + require.NoError(t, json.NewDecoder(res.Body).Decode(&errResp)) + require.Equal(t, ErrorCodeInvalidAuth, errResp.Code) + require.Equal(t, "Failed to authenticate session.", errResp.Message) + }) - payload := map[string]string{"pin": "123456"} - b, _ := json.Marshal(payload) + t.Run("provider invalid pin", func(t *testing.T) { + resetStore() + store["abc"] = map[string]string{"otp_flow_id": "flow-123"} + + h := &Handler{ + cfg: &config.Config{}, + otpProvider: stubOTPProvider{ + verifyOTP: func(context.Context, string, string) (string, error) { + return "", fmt.Errorf("verify OTP: %w", otp.ErrInvalidPIN) + }, + }, + } - req := httptest.NewRequest(http.MethodPost, "/otp/verify", bytes.NewReader(b)) - req.AddCookie(&http.Cookie{Name: "session_id", Value: "abc"}) - rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/otp/verify", bytes.NewBufferString(`{"pin":"123456"}`)) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: "session_id", Value: "abc"}) + rec := httptest.NewRecorder() - h.VerifyOTP(rec, req) + h.VerifyOTP(rec, req) - res := rec.Result() - require.Equal(t, http.StatusUnprocessableEntity, res.StatusCode) + res := rec.Result() - var got errorResponse - require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) - require.Equal(t, "AUTHORIZATION_FAILED", got.Code) - require.Equal(t, "Failed to authenticate session.", got.Message) -} + require.Equal(t, http.StatusUnprocessableEntity, res.StatusCode) + require.Equal(t, MIMEApplicationJSONCharsetUTF8, res.Header.Get("Content-Type")) + require.Equal(t, "nosniff", res.Header.Get("X-Content-Type-Options")) -func TestVerifyOTP_BadRequest(t *testing.T) { - rt := RoundTripperFunc(func(req *http.Request) (*http.Response, error) { - return &http.Response{StatusCode: http.StatusBadRequest, Body: io.NopCloser(bytes.NewReader([]byte(`{}`)))}, nil + var errResp errorResponse + require.NoError(t, json.NewDecoder(res.Body).Decode(&errResp)) + require.Equal(t, ErrorCodeInvalidAuth, errResp.Code) + require.Equal(t, "Failed to authenticate session.", errResp.Message) }) - h := &Handler{cfg: config.Default(), client: &http.Client{Transport: rt}} - resetStore() + t.Run("provider flow expired", func(t *testing.T) { + resetStore() + store["abc"] = map[string]string{"otp_flow_id": "flow-123"} + + h := &Handler{ + cfg: &config.Config{}, + otpProvider: stubOTPProvider{ + verifyOTP: func(context.Context, string, string) (string, error) { + return "", fmt.Errorf("verify OTP: %w", otp.ErrFlowExpired) + }, + }, + } - store["abc"] = map[string]string{"otp_flow_id": "123"} + req := httptest.NewRequest(http.MethodPost, "/otp/verify", bytes.NewBufferString(`{"pin":"123456"}`)) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: "session_id", Value: "abc"}) + rec := httptest.NewRecorder() - payload := map[string]string{"pin": "123456"} - b, _ := json.Marshal(payload) + h.VerifyOTP(rec, req) - req := httptest.NewRequest(http.MethodPost, "/otp/verify", bytes.NewReader(b)) - req.AddCookie(&http.Cookie{Name: "session_id", Value: "abc"}) - rec := httptest.NewRecorder() + res := rec.Result() - h.VerifyOTP(rec, req) + require.Equal(t, http.StatusUnprocessableEntity, res.StatusCode) + require.Equal(t, MIMEApplicationJSONCharsetUTF8, res.Header.Get("Content-Type")) + require.Equal(t, "nosniff", res.Header.Get("X-Content-Type-Options")) - res := rec.Result() - require.Equal(t, http.StatusInternalServerError, res.StatusCode) - require.Equal(t, "text/plain; charset=utf-8", res.Header.Get("Content-Type")) - require.Equal(t, "Internal server error\n", rec.Body.String()) -} - -func TestVerifyOTP_InternalServerError(t *testing.T) { - rt := RoundTripperFunc(func(req *http.Request) (*http.Response, error) { - return &http.Response{StatusCode: http.StatusInternalServerError, Body: io.NopCloser(bytes.NewReader([]byte(`{}`)))}, nil + var errResp errorResponse + require.NoError(t, json.NewDecoder(res.Body).Decode(&errResp)) + require.Equal(t, ErrorCodeInvalidAuth, errResp.Code) + require.Equal(t, "Failed to authenticate session.", errResp.Message) }) - h := &Handler{cfg: config.Default(), client: &http.Client{Transport: rt}} - resetStore() - - store["abc"] = map[string]string{"otp_flow_id": "123"} + t.Run("provider error", func(t *testing.T) { + resetStore() + store["abc"] = map[string]string{"otp_flow_id": "flow-123"} + + h := &Handler{ + cfg: &config.Config{}, + otpProvider: stubOTPProvider{ + verifyOTP: func(context.Context, string, string) (string, error) { + return "", fmt.Errorf("verify OTP: %w", otp.ErrUnauthorized) + }, + }, + } - payload := map[string]string{"pin": "123456"} - b, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPost, "/otp/verify", bytes.NewBufferString(`{"pin":"123456"}`)) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: "session_id", Value: "abc"}) + rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodPost, "/otp/verify", bytes.NewReader(b)) - req.AddCookie(&http.Cookie{Name: "session_id", Value: "abc"}) - rec := httptest.NewRecorder() + h.VerifyOTP(rec, req) - h.VerifyOTP(rec, req) + res := rec.Result() - res := rec.Result() - require.Equal(t, http.StatusInternalServerError, res.StatusCode) - require.Equal(t, "text/plain; charset=utf-8", res.Header.Get("Content-Type")) - require.Equal(t, "Internal server error\n", rec.Body.String()) + require.Equal(t, http.StatusInternalServerError, res.StatusCode) + require.Equal(t, "text/plain; charset=utf-8", res.Header.Get("Content-Type")) + require.Equal(t, "nosniff", res.Header.Get("X-Content-Type-Options")) + require.Equal(t, "Internal server error\n", rec.Body.String()) + }) } diff --git a/server/internal/otp/bimap/bimap.go b/server/internal/otp/bimap/bimap.go new file mode 100644 index 0000000..d2b79ea --- /dev/null +++ b/server/internal/otp/bimap/bimap.go @@ -0,0 +1,55 @@ +package bimap + +type BiMap[K, V comparable] struct { + forward map[K]V + backward map[V]K +} + +func New[K, V comparable]() *BiMap[K, V] { + return &BiMap[K, V]{ + forward: make(map[K]V), + backward: make(map[V]K), + } +} + +func (b *BiMap[K, V]) Put(k K, v V) { + if oldV, ok := b.forward[k]; ok { + delete(b.backward, oldV) + } + if oldK, ok := b.backward[v]; ok { + delete(b.forward, oldK) + } + + b.forward[k] = v + b.backward[v] = k +} + +func (b *BiMap[K, V]) Get(k K) (V, bool) { + v, ok := b.forward[k] + return v, ok +} + +func (b *BiMap[K, V]) GetByValue(v V) (K, bool) { + k, ok := b.backward[v] + return k, ok +} + +func (b *BiMap[K, V]) Delete(k K) { + v, ok := b.forward[k] + if !ok { + return + } + + delete(b.forward, k) + delete(b.backward, v) +} + +func (b *BiMap[K, V]) DeleteByValue(v V) { + k, ok := b.backward[v] + if !ok { + return + } + + delete(b.forward, k) + delete(b.backward, v) +} diff --git a/server/internal/otp/mock.go b/server/internal/otp/mock.go new file mode 100644 index 0000000..f434d2a --- /dev/null +++ b/server/internal/otp/mock.go @@ -0,0 +1,99 @@ +package otp + +import ( + "context" + "errors" + "strings" + "sync" + + "github.com/String-sg/teacher-workspace/server/internal/otp/bimap" + "github.com/String-sg/teacher-workspace/server/pkg/random" +) + +const ( + mockFlowIDLength = 8 + mockSuccessPIN = "112233" + mockExpiredPIN = "223344" +) + +// MockProvider implements [Provider] for local and lower-environment use. +type MockProvider struct { + mu sync.Mutex + + allowedEmails map[string]struct{} + + // store maps emails to flow IDs and vice versa. + store *bimap.BiMap[string, string] +} + +// NewMockProvider creates a [MockProvider] with an allowlist of exact email matches. +func NewMockProvider(emails []string) *MockProvider { + allowedEmails := make(map[string]struct{}, len(emails)) + for _, email := range emails { + email = strings.TrimSpace(strings.ToLower(email)) + if email == "" { + continue + } + + allowedEmails[email] = struct{}{} + } + + return &MockProvider{ + allowedEmails: allowedEmails, + store: bimap.New[string, string](), + } +} + +// RequestOTP implements [Provider.RequestOTP]. +func (p *MockProvider) RequestOTP(_ context.Context, email string) (string, error) { + email = strings.TrimSpace(strings.ToLower(email)) + if _, ok := p.allowedEmails[email]; !ok { + return "", ErrEmailNotAllowed + } + + p.mu.Lock() + defer p.mu.Unlock() + + if _, ok := p.store.Get(email); ok { + p.store.Delete(email) + } + + var flowID string + for range 3 { + candidateFlowID := random.Base58(mockFlowIDLength) + if _, exists := p.store.GetByValue(candidateFlowID); exists { + continue + } + + flowID = candidateFlowID + break + } + if flowID == "" { + return "", errors.New("mock: failed to generate unique flow ID after 3 attempts") + } + + p.store.Put(email, flowID) + + return flowID, nil +} + +// VerifyOTP implements [Provider.VerifyOTP]. +func (p *MockProvider) VerifyOTP(_ context.Context, flowID, pin string) (string, error) { + p.mu.Lock() + defer p.mu.Unlock() + + email, ok := p.store.GetByValue(flowID) + if !ok { + return "", ErrFlowExpired + } + + switch pin { + case mockSuccessPIN: + p.store.DeleteByValue(flowID) + return email, nil + case mockExpiredPIN: + return "", ErrFlowExpired + default: + return "", ErrInvalidPIN + } +} diff --git a/server/internal/otp/mock_test.go b/server/internal/otp/mock_test.go new file mode 100644 index 0000000..6f11d75 --- /dev/null +++ b/server/internal/otp/mock_test.go @@ -0,0 +1,171 @@ +package otp + +import ( + "context" + "errors" + "testing" +) + +func TestMockProvider_RequestOTP(t *testing.T) { + t.Run("whitelisted email returns flow ID", func(t *testing.T) { + p := NewMockProvider([]string{"whitelisted@example.com"}) + + flowID, err := p.RequestOTP(context.Background(), " whitelisted@example.com ") + if err != nil { + t.Fatalf("want: nil; got %v", err) + } + if want, got := mockFlowIDLength, len(flowID); want != got { + t.Fatalf("want flow ID length: %d; got %d", want, got) + } + }) + + t.Run("non-whitelisted email is rejected", func(t *testing.T) { + p := NewMockProvider([]string{"whitelisted@example.com"}) + + flowID, err := p.RequestOTP(context.Background(), "other@example.com") + if !errors.Is(err, ErrEmailNotAllowed) { + t.Fatalf("want ErrEmailNotAllowed; got %v", err) + } + if want := ""; want != flowID { + t.Fatalf("want flow ID: %q; got %q", want, flowID) + } + }) + + t.Run("repeated request with same email returns new flow ID", func(t *testing.T) { + p := NewMockProvider([]string{"whitelisted@example.com"}) + + oldFlowID, err := p.RequestOTP(context.Background(), "whitelisted@example.com") + if err != nil { + t.Fatalf("want: nil; got %v", err) + } + + newFlowID, err := p.RequestOTP(context.Background(), "whitelisted@example.com") + if err != nil { + t.Fatalf("want: nil; got %v", err) + } + + if oldFlowID == newFlowID { + t.Fatalf("want distinct flow IDs; got %q", oldFlowID) + } + }) +} + +func TestMockProvider_VerifyOTP(t *testing.T) { + t.Run("success PIN returns email and consumes flow", func(t *testing.T) { + p := NewMockProvider([]string{"whitelisted@example.com"}) + p.store.Put("whitelisted@example.com", "flow-123") + + email, err := p.VerifyOTP(context.Background(), "flow-123", mockSuccessPIN) + if err != nil { + t.Fatalf("want: nil; got %v", err) + } + if want := "whitelisted@example.com"; want != email { + t.Fatalf("want email: %q; got %q", want, email) + } + + email, err = p.VerifyOTP(context.Background(), "flow-123", mockSuccessPIN) + if !errors.Is(err, ErrFlowExpired) { + t.Fatalf("want ErrFlowExpired; got %v", err) + } + if want := ""; want != email { + t.Fatalf("want email: %q; got %q", want, email) + } + }) + + t.Run("expired PIN returns ErrFlowExpired without consuming flow", func(t *testing.T) { + p := NewMockProvider([]string{"whitelisted@example.com"}) + p.store.Put("whitelisted@example.com", "flow-123") + + email, err := p.VerifyOTP(context.Background(), "flow-123", mockExpiredPIN) + if !errors.Is(err, ErrFlowExpired) { + t.Fatalf("want ErrFlowExpired; got %v", err) + } + if want := ""; want != email { + t.Fatalf("want email: %q; got %q", want, email) + } + + email, err = p.VerifyOTP(context.Background(), "flow-123", mockExpiredPIN) + if !errors.Is(err, ErrFlowExpired) { + t.Fatalf("want ErrFlowExpired; got %v", err) + } + if want := ""; want != email { + t.Fatalf("want email: %q; got %q", want, email) + } + + email, err = p.VerifyOTP(context.Background(), "flow-123", mockSuccessPIN) + if err != nil { + t.Fatalf("want: nil; got %v", err) + } + if want := "whitelisted@example.com"; want != email { + t.Fatalf("want email: %q; got %q", want, email) + } + }) + + t.Run("invalid PIN returns ErrInvalidPIN without consuming flow", func(t *testing.T) { + p := NewMockProvider([]string{"whitelisted@example.com"}) + p.store.Put("whitelisted@example.com", "flow-123") + + email, err := p.VerifyOTP(context.Background(), "flow-123", "999999") + if !errors.Is(err, ErrInvalidPIN) { + t.Fatalf("want ErrInvalidPIN; got %v", err) + } + if want := ""; want != email { + t.Fatalf("want email: %q; got %q", want, email) + } + + email, err = p.VerifyOTP(context.Background(), "flow-123", mockSuccessPIN) + if err != nil { + t.Fatalf("want: nil; got %v", err) + } + if want := "whitelisted@example.com"; want != email { + t.Fatalf("want email: %q; got %q", want, email) + } + }) + + t.Run("unknown flow returns ErrFlowExpired", func(t *testing.T) { + p := NewMockProvider([]string{"whitelisted@example.com"}) + + email, err := p.VerifyOTP(context.Background(), "flow-unknown", mockSuccessPIN) + + if !errors.Is(err, ErrFlowExpired) { + t.Fatalf("want ErrFlowExpired; got %v", err) + } + if want := ""; want != email { + t.Fatalf("want email: %q; got %q", want, email) + } + }) +} + +func TestMockProvider_FlowLifecycle(t *testing.T) { + p := NewMockProvider([]string{"whitelisted@example.com"}) + + oldFlowID, err := p.RequestOTP(context.Background(), "whitelisted@example.com") + if err != nil { + t.Fatalf("want: nil; got %v", err) + } + + newFlowID, err := p.RequestOTP(context.Background(), "whitelisted@example.com") + if err != nil { + t.Fatalf("want: nil; got %v", err) + } + + if oldFlowID == newFlowID { + t.Fatalf("want distinct flow IDs; got %q", oldFlowID) + } + + email, err := p.VerifyOTP(context.Background(), oldFlowID, mockSuccessPIN) + if !errors.Is(err, ErrFlowExpired) { + t.Fatalf("want ErrFlowExpired; got %v", err) + } + if want := ""; want != email { + t.Fatalf("want email: %q; got %q", want, email) + } + + email, err = p.VerifyOTP(context.Background(), newFlowID, mockSuccessPIN) + if err != nil { + t.Fatalf("want: nil; got %v", err) + } + if want := "whitelisted@example.com"; want != email { + t.Fatalf("want email: %q; got %q", want, email) + } +} diff --git a/server/internal/otp/otp.go b/server/internal/otp/otp.go new file mode 100644 index 0000000..427667c --- /dev/null +++ b/server/internal/otp/otp.go @@ -0,0 +1,39 @@ +package otp + +import ( + "context" + "errors" +) + +var ( + // ErrRateLimited indicates the request exceeded the allowed request rate. + ErrRateLimited = errors.New("too many requests") + + // ErrDomainNotAllowed indicates the email domain is not permitted. + ErrDomainNotAllowed = errors.New("email domain not allowed") + + // ErrEmailNotAllowed indicates the email is not permitted. + ErrEmailNotAllowed = errors.New("email not allowed") + + // ErrUnauthorized indicates the underlying provider rejected the request due to invalid or missing credentials. + ErrUnauthorized = errors.New("unauthorized") + + // ErrFlowExpired indicates the OTP flow is no longer valid. + ErrFlowExpired = errors.New("flow expired") + + // ErrInvalidPIN indicates the provided PIN does not match the flow. + ErrInvalidPIN = errors.New("invalid PIN") +) + +// Provider manages OTP request and verification flows. +type Provider interface { + // RequestOTP sends an OTP to the given email and returns a flow ID. + // May return [ErrRateLimited], [ErrDomainNotAllowed], [ErrEmailNotAllowed], or [ErrUnauthorized]. + // Use [errors.Is] to check; other errors may be returned for unexpected failures. + RequestOTP(ctx context.Context, email string) (string, error) + + // VerifyOTP validates a PIN against the given flow and returns the verified email. + // May return [ErrInvalidPIN], [ErrFlowExpired], or [ErrUnauthorized]. + // Use [errors.Is] to check; other errors may be returned for unexpected failures. + VerifyOTP(ctx context.Context, flowID, pin string) (string, error) +} diff --git a/server/internal/otp/otpaas.go b/server/internal/otp/otpaas.go new file mode 100644 index 0000000..511c1ec --- /dev/null +++ b/server/internal/otp/otpaas.go @@ -0,0 +1,194 @@ +package otp + +import ( + "bytes" + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +// OTPaaSProvider implements [Provider] using the OTPaaS API. +type OTPaaSProvider struct { + host string + appID string + appNamespace string + token string + + client *http.Client +} + +// NewOTPaaSProvider creates an [OTPaaSProvider] with the given configuration. +func NewOTPaaSProvider(host, appID, appNamespace, secret string, timeout time.Duration) *OTPaaSProvider { + mac := hmac.New(sha256.New, []byte(secret)) + mac.Write([]byte(appID)) + + sig := hex.EncodeToString(mac.Sum(nil)) + + payload := appNamespace + ":" + appID + ":" + sig + token := base64.StdEncoding.EncodeToString([]byte(payload)) + + return &OTPaaSProvider{ + host: host, + appID: appID, + appNamespace: appNamespace, + token: token, + + client: &http.Client{Timeout: timeout}, + } +} + +// RequestOTP implements [Provider.RequestOTP]. +func (p *OTPaaSProvider) RequestOTP(ctx context.Context, email string) (string, error) { + reqBody, err := json.Marshal(struct { + Email string `json:"email"` + }{ + Email: strings.TrimSpace(strings.ToLower(email)), + }) + if err != nil { + return "", fmt.Errorf("otpaas: failed to marshal request body: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.host+"/otp", bytes.NewReader(reqBody)) + if err != nil { + return "", fmt.Errorf("otpaas: failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+p.token) + req.Header.Set("X-App-Id", p.appID) + req.Header.Set("X-App-Namespace", p.appNamespace) + + resp, err := p.client.Do(req) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return "", fmt.Errorf("otpaas: request timed out: %w", err) + } + return "", fmt.Errorf("otpaas: failed to send request: %w", err) + } + defer func() { + _ = resp.Body.Close() + }() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("otpaas: failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + var detail struct { + Code int `json:"code"` + Message string `json:"message"` + } + _ = json.Unmarshal(respBody, &detail) + + switch resp.StatusCode { + case http.StatusBadRequest: + if detail.Code == 2008 { + return "", fmt.Errorf("otpaas: %w", ErrRateLimited) + } + return "", fmt.Errorf("otpaas: bad request (code %d): %q", detail.Code, detail.Message) + case http.StatusUnauthorized: + return "", fmt.Errorf("otpaas: %w", ErrUnauthorized) + case http.StatusForbidden: + if detail.Code == 2005 { + return "", fmt.Errorf("otpaas: %w", ErrDomainNotAllowed) + } + return "", fmt.Errorf("otpaas: forbidden (code %d): %q", detail.Code, detail.Message) + default: + return "", fmt.Errorf("otpaas: unexpected status %d (code %d): %q", resp.StatusCode, detail.Code, detail.Message) + } + } + + var result struct { + ID string `json:"id"` + } + if err := json.Unmarshal(respBody, &result); err != nil { + return "", fmt.Errorf("otpaas: failed to unmarshal response body: %w", err) + } + + if result.ID == "" { + return "", errors.New("otpaas: returned empty flow id") + } + + return result.ID, nil +} + +// VerifyOTP implements [Provider.VerifyOTP]. +func (p *OTPaaSProvider) VerifyOTP(ctx context.Context, flowID string, pin string) (string, error) { + reqBody, err := json.Marshal(struct { + PIN string `json:"pin"` + }{ + PIN: pin, + }) + if err != nil { + return "", fmt.Errorf("otpaas: failed to marshal request body: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPut, p.host+"/otp/"+flowID, bytes.NewReader(reqBody)) + if err != nil { + return "", fmt.Errorf("otpaas: failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+p.token) + req.Header.Set("X-App-Id", p.appID) + req.Header.Set("X-App-Namespace", p.appNamespace) + + resp, err := p.client.Do(req) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return "", fmt.Errorf("otpaas: request timed out: %w", err) + } + return "", fmt.Errorf("otpaas: failed to send request: %w", err) + } + defer func() { + _ = resp.Body.Close() + }() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("otpaas: failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + var detail struct { + Code int `json:"code"` + Message string `json:"message"` + } + _ = json.Unmarshal(respBody, &detail) + + switch resp.StatusCode { + case http.StatusUnauthorized: + if detail.Code == 1006 { + return "", ErrInvalidPIN + } + return "", ErrUnauthorized + case http.StatusNotFound: + return "", ErrFlowExpired + default: + return "", fmt.Errorf("otpaas: unexpected status %d (code %d): %q", resp.StatusCode, detail.Code, detail.Message) + } + } + + var result struct { + Email string `json:"email"` + } + if err := json.Unmarshal(respBody, &result); err != nil { + return "", fmt.Errorf("otpaas: failed to unmarshal response body: %w", err) + } + + if result.Email == "" { + return "", errors.New("otpaas: returned empty email") + } + + return result.Email, nil +} diff --git a/server/internal/otp/otpaas_test.go b/server/internal/otp/otpaas_test.go new file mode 100644 index 0000000..47f7157 --- /dev/null +++ b/server/internal/otp/otpaas_test.go @@ -0,0 +1,469 @@ +package otp + +import ( + "bytes" + "context" + "errors" + "io" + "net/http" + "testing" + "time" +) + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } + +func newTestOTPaaSProvider(rt http.RoundTripper) *OTPaaSProvider { + p := NewOTPaaSProvider("https://otp.example.com", "app-id", "app-namespace", "secret", 5*time.Second) + p.client = &http.Client{Transport: rt} + return p +} + +func TestOTPaaSProvider_RequestOTP(t *testing.T) { + t.Run("sends expected request and returns flow ID", func(t *testing.T) { + var captured *http.Request + var capturedBody string + p := newTestOTPaaSProvider(roundTripFunc(func(r *http.Request) (*http.Response, error) { + captured = r.Clone(r.Context()) + + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("failed to read request body: %v", err) + } + capturedBody = string(body) + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader([]byte(`{"id":"flow-42"}`))), + }, nil + })) + + id, err := p.RequestOTP(context.Background(), "xyz@example.com") + + if err != nil { + t.Fatalf("want: nil; got %v", err) + } + if want := "flow-42"; want != id { + t.Errorf("want flow ID: %q; got %q", want, id) + } + + if captured == nil { + t.Fatalf("want captured request; got nil") + } + if want, got := http.MethodPost, captured.Method; want != got { + t.Errorf("want method: %q; got %q", want, got) + } + if want, got := "https://otp.example.com/otp", captured.URL.String(); want != got { + t.Errorf("want URL: %q; got %q", want, got) + } + if want, got := "application/json", captured.Header.Get("Content-Type"); want != got { + t.Errorf("want Content-Type header: %q; got %q", want, got) + } + if want, got := "Bearer "+p.token, captured.Header.Get("Authorization"); want != got { + t.Errorf("want Authorization header: %q; got %q", want, got) + } + if want, got := "app-id", captured.Header.Get("X-App-Id"); want != got { + t.Errorf("want X-App-Id header: %q; got %q", want, got) + } + if want, got := "app-namespace", captured.Header.Get("X-App-Namespace"); want != got { + t.Errorf("want X-App-Namespace header: %q; got %q", want, got) + } + if want := `{"email":"xyz@example.com"}`; want != capturedBody { + t.Errorf("want request body: %q; got %q", want, capturedBody) + } + }) + + t.Run("bad request status with rate limit code returns ErrRateLimited", func(t *testing.T) { + p := newTestOTPaaSProvider(roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusBadRequest, + Body: io.NopCloser(bytes.NewReader([]byte(`{"code":2008,"message":"rate limited"}`))), + }, nil + })) + + id, err := p.RequestOTP(context.Background(), "xyz@example.com") + + if err == nil { + t.Fatalf("want err; got nil") + } + if want := ""; want != id { + t.Errorf("want flow ID: %q; got %q", want, id) + } + if want, got := true, errors.Is(err, ErrRateLimited); want != got { + t.Errorf("want ErrRateLimited match: %v; got %v", want, got) + } + }) + + t.Run("bad request status with unknown code returns error", func(t *testing.T) { + p := newTestOTPaaSProvider(roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusBadRequest, + Body: io.NopCloser(bytes.NewReader([]byte(`{"code":9999,"message":"something else"}`))), + }, nil + })) + + id, err := p.RequestOTP(context.Background(), "xyz@example.com") + + if err == nil { + t.Fatalf("want err; got nil") + } + if want := ""; want != id { + t.Errorf("want flow ID: %q; got %q", want, id) + } + if want, got := false, errors.Is(err, ErrRateLimited); want != got { + t.Errorf("want ErrRateLimited match: %v; got %v", want, got) + } + }) + + t.Run("unauthorized status returns ErrUnauthorized", func(t *testing.T) { + p := newTestOTPaaSProvider(roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusUnauthorized, + Body: io.NopCloser(bytes.NewReader([]byte(`{}`))), + }, nil + })) + + id, err := p.RequestOTP(context.Background(), "xyz@example.com") + + if err == nil { + t.Fatalf("want err; got nil") + } + if want := ""; want != id { + t.Errorf("want flow ID: %q; got %q", want, id) + } + if want, got := true, errors.Is(err, ErrUnauthorized); want != got { + t.Errorf("want ErrUnauthorized match: %v; got %v", want, got) + } + }) + + t.Run("forbidden status with domain not allowed code returns ErrDomainNotAllowed", func(t *testing.T) { + p := newTestOTPaaSProvider(roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusForbidden, + Body: io.NopCloser(bytes.NewReader([]byte(`{"code":2005,"message":"domain not allowed"}`))), + }, nil + })) + + id, err := p.RequestOTP(context.Background(), "xyz@example.com") + + if err == nil { + t.Fatalf("want err; got nil") + } + if want := ""; want != id { + t.Errorf("want flow ID: %q; got %q", want, id) + } + if want, got := true, errors.Is(err, ErrDomainNotAllowed); want != got { + t.Errorf("want ErrDomainNotAllowed match: %v; got %v", want, got) + } + }) + + t.Run("forbidden status with unknown code returns error", func(t *testing.T) { + p := newTestOTPaaSProvider(roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusForbidden, + Body: io.NopCloser(bytes.NewReader([]byte(`{"code":1234,"message":"other"}`))), + }, nil + })) + + id, err := p.RequestOTP(context.Background(), "xyz@example.com") + + if err == nil { + t.Fatalf("want err; got nil") + } + if want := ""; want != id { + t.Errorf("want flow ID: %q; got %q", want, id) + } + if want, got := false, errors.Is(err, ErrDomainNotAllowed); want != got { + t.Errorf("want ErrDomainNotAllowed match: %v; got %v", want, got) + } + }) + + t.Run("unexpected status returns error", func(t *testing.T) { + p := newTestOTPaaSProvider(roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusInternalServerError, + Body: io.NopCloser(bytes.NewReader([]byte(`{"code":0,"message":"boom"}`))), + }, nil + })) + + id, err := p.RequestOTP(context.Background(), "xyz@example.com") + + if err == nil { + t.Fatalf("want err; got nil") + } + if want := ""; want != id { + t.Errorf("want flow ID: %q; got %q", want, id) + } + if want, got := false, errors.Is(err, ErrRateLimited); want != got { + t.Errorf("want ErrRateLimited match: %v; got %v", want, got) + } + if want, got := false, errors.Is(err, ErrDomainNotAllowed); want != got { + t.Errorf("want ErrDomainNotAllowed match: %v; got %v", want, got) + } + if want, got := false, errors.Is(err, ErrUnauthorized); want != got { + t.Errorf("want ErrUnauthorized match: %v; got %v", want, got) + } + }) + + t.Run("request timeout", func(t *testing.T) { + p := newTestOTPaaSProvider(roundTripFunc(func(r *http.Request) (*http.Response, error) { + <-r.Context().Done() + return nil, r.Context().Err() + })) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + + id, err := p.RequestOTP(ctx, "xyz@example.com") + + if err == nil { + t.Fatalf("want err; got nil") + } + if want := ""; want != id { + t.Errorf("want flow ID: %q; got %q", want, id) + } + if want, got := true, errors.Is(err, context.DeadlineExceeded); want != got { + t.Errorf("want context deadline exceeded match: %v; got %v", want, got) + } + }) + + t.Run("transport error", func(t *testing.T) { + p := newTestOTPaaSProvider(roundTripFunc(func(*http.Request) (*http.Response, error) { + return nil, errors.New("connection refused") + })) + + id, err := p.RequestOTP(context.Background(), "xyz@example.com") + + if err == nil { + t.Fatalf("want err; got nil") + } + if want := ""; want != id { + t.Errorf("want flow ID: %q; got %q", want, id) + } + if want, got := false, errors.Is(err, context.DeadlineExceeded); want != got { + t.Errorf("want context deadline exceeded match: %v; got %v", want, got) + } + }) + + t.Run("empty flow ID returns error", func(t *testing.T) { + p := newTestOTPaaSProvider(roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader([]byte(`{"id":""}`))), + }, nil + })) + + id, err := p.RequestOTP(context.Background(), "xyz@example.com") + + if err == nil { + t.Fatalf("want err; got nil") + } + if want := ""; want != id { + t.Errorf("want flow ID: %q; got %q", want, id) + } + }) +} + +func TestOTPaaSProvider_VerifyOTP(t *testing.T) { + t.Run("sends expected request and returns email", func(t *testing.T) { + var captured *http.Request + var capturedBody string + p := newTestOTPaaSProvider(roundTripFunc(func(r *http.Request) (*http.Response, error) { + captured = r.Clone(r.Context()) + + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("failed to read request body: %v", err) + } + capturedBody = string(body) + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader([]byte(`{"email":"xyz@example.com"}`))), + }, nil + })) + + email, err := p.VerifyOTP(context.Background(), "flow-42", "123456") + + if err != nil { + t.Fatalf("want: nil; got %v", err) + } + if want := "xyz@example.com"; want != email { + t.Errorf("want email: %q; got %q", want, email) + } + + if captured == nil { + t.Fatalf("want captured request; got nil") + } + if want, got := http.MethodPut, captured.Method; want != got { + t.Errorf("want method: %q; got %q", want, got) + } + if want, got := "https://otp.example.com/otp/flow-42", captured.URL.String(); want != got { + t.Errorf("want URL: %q; got %q", want, got) + } + if want, got := "application/json", captured.Header.Get("Content-Type"); want != got { + t.Errorf("want Content-Type header: %q; got %q", want, got) + } + if want, got := "Bearer "+p.token, captured.Header.Get("Authorization"); want != got { + t.Errorf("want Authorization header: %q; got %q", want, got) + } + if want, got := "app-id", captured.Header.Get("X-App-Id"); want != got { + t.Errorf("want X-App-Id header: %q; got %q", want, got) + } + if want, got := "app-namespace", captured.Header.Get("X-App-Namespace"); want != got { + t.Errorf("want X-App-Namespace header: %q; got %q", want, got) + } + if want := `{"pin":"123456"}`; want != capturedBody { + t.Errorf("want request body: %q; got %q", want, capturedBody) + } + }) + + t.Run("empty email returns error", func(t *testing.T) { + p := newTestOTPaaSProvider(roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader([]byte(`{"email":""}`))), + }, nil + })) + + email, err := p.VerifyOTP(context.Background(), "flow-42", "123456") + + if err == nil { + t.Fatalf("want err; got nil") + } + if want := ""; want != email { + t.Errorf("want email: %q; got %q", want, email) + } + }) + + t.Run("unauthorized status with invalid PIN code returns ErrInvalidPIN", func(t *testing.T) { + p := newTestOTPaaSProvider(roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusUnauthorized, + Body: io.NopCloser(bytes.NewReader([]byte(`{"code":1006,"message":"invalid pin"}`))), + }, nil + })) + + email, err := p.VerifyOTP(context.Background(), "flow-42", "000000") + + if err == nil { + t.Fatalf("want err; got nil") + } + if want := ""; want != email { + t.Errorf("want email: %q; got %q", want, email) + } + if want, got := true, errors.Is(err, ErrInvalidPIN); want != got { + t.Errorf("want ErrInvalidPIN match: %v; got %v", want, got) + } + }) + + t.Run("unauthorized status with unknown code returns ErrUnauthorized", func(t *testing.T) { + p := newTestOTPaaSProvider(roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusUnauthorized, + Body: io.NopCloser(bytes.NewReader([]byte(`{"code":9999,"message":"bad token"}`))), + }, nil + })) + + email, err := p.VerifyOTP(context.Background(), "flow-42", "123456") + + if err == nil { + t.Fatalf("want err; got nil") + } + if want := ""; want != email { + t.Errorf("want email: %q; got %q", want, email) + } + if want, got := true, errors.Is(err, ErrUnauthorized); want != got { + t.Errorf("want ErrUnauthorized match: %v; got %v", want, got) + } + }) + + t.Run("not found status returns ErrFlowExpired", func(t *testing.T) { + p := newTestOTPaaSProvider(roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusNotFound, + Body: io.NopCloser(bytes.NewReader([]byte(`{}`))), + }, nil + })) + + email, err := p.VerifyOTP(context.Background(), "flow-42", "123456") + + if err == nil { + t.Fatalf("want err; got nil") + } + if want := ""; want != email { + t.Errorf("want email: %q; got %q", want, email) + } + if want, got := true, errors.Is(err, ErrFlowExpired); want != got { + t.Errorf("want ErrFlowExpired match: %v; got %v", want, got) + } + }) + + t.Run("unexpected status returns error", func(t *testing.T) { + p := newTestOTPaaSProvider(roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusInternalServerError, + Body: io.NopCloser(bytes.NewReader([]byte(`{"code":0,"message":"oops"}`))), + }, nil + })) + + email, err := p.VerifyOTP(context.Background(), "flow-42", "123456") + + if err == nil { + t.Fatalf("want err; got nil") + } + if want := ""; want != email { + t.Errorf("want email: %q; got %q", want, email) + } + if want, got := false, errors.Is(err, ErrInvalidPIN); want != got { + t.Errorf("want ErrInvalidPIN match: %v; got %v", want, got) + } + if want, got := false, errors.Is(err, ErrUnauthorized); want != got { + t.Errorf("want ErrUnauthorized match: %v; got %v", want, got) + } + if want, got := false, errors.Is(err, ErrFlowExpired); want != got { + t.Errorf("want ErrFlowExpired match: %v; got %v", want, got) + } + }) + + t.Run("request timeout", func(t *testing.T) { + p := newTestOTPaaSProvider(roundTripFunc(func(r *http.Request) (*http.Response, error) { + <-r.Context().Done() + return nil, r.Context().Err() + })) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + + email, err := p.VerifyOTP(ctx, "flow-42", "123456") + + if err == nil { + t.Fatalf("want err; got nil") + } + if want := ""; want != email { + t.Errorf("want email: %q; got %q", want, email) + } + if want, got := true, errors.Is(err, context.DeadlineExceeded); want != got { + t.Errorf("want context deadline exceeded match: %v; got %v", want, got) + } + }) + + t.Run("transport error", func(t *testing.T) { + p := newTestOTPaaSProvider(roundTripFunc(func(*http.Request) (*http.Response, error) { + return nil, errors.New("connection refused") + })) + + email, err := p.VerifyOTP(context.Background(), "flow-42", "123456") + + if err == nil { + t.Fatalf("want err; got nil") + } + if want := ""; want != email { + t.Errorf("want email: %q; got %q", want, email) + } + if want, got := false, errors.Is(err, context.DeadlineExceeded); want != got { + t.Errorf("want context deadline exceeded match: %v; got %v", want, got) + } + }) +}