diff --git a/server/internal/config/config_test.go b/server/internal/config/config_test.go new file mode 100644 index 0000000..e007441 --- /dev/null +++ b/server/internal/config/config_test.go @@ -0,0 +1,365 @@ +package config + +import ( + "strings" + "testing" +) + +func defaultFixtureConfig() *Config { + return fixtureConfigWithProvider(OTPProviderOTPaaS) +} + +func fixtureConfigWithProvider(provider Provider) *Config { + cfg := Default() + cfg.OTP.Provider = provider + + switch provider { + case OTPProviderOTPaaS: + cfg.OTP.OTPaaS.ID = "app-id" + cfg.OTP.OTPaaS.Namespace = "app-namespace" + cfg.OTP.OTPaaS.Secret = "secret" + case OTPProviderMock: + cfg.OTP.Mock.AllowedEmails = []string{"whitelisted@example.com"} + } + + return cfg +} + +func TestConfig_Validate(t *testing.T) { + t.Run("valid environment", func(t *testing.T) { + cfg := defaultFixtureConfig() + cfg.Environment = EnvironmentProduction + cfg.BundleDirectory = t.TempDir() + + err := cfg.Validate() + + if err != nil { + t.Fatalf("want nil; got %v", err) + } + }) + + t.Run("invalid environment", func(t *testing.T) { + cfg := defaultFixtureConfig() + cfg.Environment = "staging" + + err := cfg.Validate() + + if err == nil { + t.Fatalf("want error; got nil") + } + if want := "TW_ENV"; !strings.Contains(err.Error(), want) { + t.Fatalf("want err containing %q; got %q", want, err.Error()) + } + }) + + t.Run("invalid provider", func(t *testing.T) { + cfg := defaultFixtureConfig() + cfg.OTP.Provider = "sms" + + err := cfg.Validate() + + if err == nil { + t.Fatalf("want error; got nil") + } + if want := "TW_OTP_PROVIDER"; !strings.Contains(err.Error(), want) { + t.Fatalf("want err containing %q; got %q", want, err.Error()) + } + }) + + t.Run("empty allowed email domains", func(t *testing.T) { + cfg := defaultFixtureConfig() + cfg.AllowedEmailDomains = nil + + err := cfg.Validate() + + if err == nil { + t.Fatalf("want error; got nil") + } + if want := "TW_ALLOWED_EMAIL_DOMAINS"; !strings.Contains(err.Error(), want) { + t.Fatalf("want err containing %q; got %q", want, err.Error()) + } + }) + + t.Run("server port zero", func(t *testing.T) { + cfg := defaultFixtureConfig() + cfg.Server.Port = 0 + + err := cfg.Validate() + + if err == nil { + t.Fatalf("want error; got nil") + } + if want := "TW_SERVER_PORT"; !strings.Contains(err.Error(), want) { + t.Fatalf("want err containing %q; got %q", want, err.Error()) + } + }) + + t.Run("server port exceeds max value", func(t *testing.T) { + cfg := defaultFixtureConfig() + cfg.Server.Port = 65536 + + err := cfg.Validate() + + if err == nil { + t.Fatalf("want error; got nil") + } + if want := "TW_SERVER_PORT"; !strings.Contains(err.Error(), want) { + t.Fatalf("want err containing %q; got %q", want, err.Error()) + } + }) + + t.Run("server zero read header timeout", func(t *testing.T) { + cfg := defaultFixtureConfig() + cfg.Server.ReadHeaderTimeout = 0 + + err := cfg.Validate() + + if err == nil { + t.Fatalf("want error; got nil") + } + if want := "TW_SERVER_READ_HEADER_TIMEOUT"; !strings.Contains(err.Error(), want) { + t.Fatalf("want err containing %q; got %q", want, err.Error()) + } + }) + + t.Run("server non-positive read header timeout", func(t *testing.T) { + cfg := defaultFixtureConfig() + cfg.Server.ReadHeaderTimeout = -1 + + err := cfg.Validate() + + if err == nil { + t.Fatalf("want error; got nil") + } + if want := "TW_SERVER_READ_HEADER_TIMEOUT"; !strings.Contains(err.Error(), want) { + t.Fatalf("want err containing %q; got %q", want, err.Error()) + } + }) + + t.Run("server zero read timeout", func(t *testing.T) { + cfg := defaultFixtureConfig() + cfg.Server.ReadTimeout = 0 + + err := cfg.Validate() + + if err == nil { + t.Fatalf("want error; got nil") + } + if want := "TW_SERVER_READ_TIMEOUT"; !strings.Contains(err.Error(), want) { + t.Fatalf("want err containing %q; got %q", want, err.Error()) + } + }) + + t.Run("server non-positive read timeout", func(t *testing.T) { + cfg := defaultFixtureConfig() + cfg.Server.ReadTimeout = -1 + + err := cfg.Validate() + + if err == nil { + t.Fatalf("want error; got nil") + } + if want := "TW_SERVER_READ_TIMEOUT"; !strings.Contains(err.Error(), want) { + t.Fatalf("want err containing %q; got %q", want, err.Error()) + } + }) + + t.Run("server zero write timeout", func(t *testing.T) { + cfg := defaultFixtureConfig() + cfg.Server.WriteTimeout = 0 + + err := cfg.Validate() + + if err == nil { + t.Fatalf("want error; got nil") + } + if want := "TW_SERVER_WRITE_TIMEOUT"; !strings.Contains(err.Error(), want) { + t.Fatalf("want err containing %q; got %q", want, err.Error()) + } + }) + + t.Run("server non-positive write timeout", func(t *testing.T) { + cfg := defaultFixtureConfig() + cfg.Server.WriteTimeout = -1 + + err := cfg.Validate() + + if err == nil { + t.Fatalf("want error; got nil") + } + if want := "TW_SERVER_WRITE_TIMEOUT"; !strings.Contains(err.Error(), want) { + t.Fatalf("want err containing %q; got %q", want, err.Error()) + } + }) + + t.Run("server zero idle timeout", func(t *testing.T) { + cfg := defaultFixtureConfig() + cfg.Server.IdleTimeout = 0 + + err := cfg.Validate() + + if err == nil { + t.Fatalf("want error; got nil") + } + if want := "TW_SERVER_IDLE_TIMEOUT"; !strings.Contains(err.Error(), want) { + t.Fatalf("want err containing %q; got %q", want, err.Error()) + } + }) + + t.Run("server non-positive idle timeout", func(t *testing.T) { + cfg := defaultFixtureConfig() + cfg.Server.IdleTimeout = -1 + + err := cfg.Validate() + + if err == nil { + t.Fatalf("want error; got nil") + } + if want := "TW_SERVER_IDLE_TIMEOUT"; !strings.Contains(err.Error(), want) { + t.Fatalf("want err containing %q; got %q", want, err.Error()) + } + }) + + t.Run("valid otpaas config", func(t *testing.T) { + cfg := fixtureConfigWithProvider(OTPProviderOTPaaS) + + err := cfg.Validate() + + if err != nil { + t.Fatalf("want nil; got %v", err) + } + }) + + t.Run("otpaas missing host", func(t *testing.T) { + cfg := fixtureConfigWithProvider(OTPProviderOTPaaS) + cfg.OTP.OTPaaS.Host = "" + + err := cfg.Validate() + + if err == nil { + t.Fatalf("want error; got nil") + } + if want := "TW_OTPAAS_HOST"; !strings.Contains(err.Error(), want) { + t.Fatalf("want err containing %q; got %q", want, err.Error()) + } + }) + + t.Run("otpaas missing app ID", func(t *testing.T) { + cfg := fixtureConfigWithProvider(OTPProviderOTPaaS) + cfg.OTP.OTPaaS.ID = "" + + err := cfg.Validate() + + if err == nil { + t.Fatalf("want error; got nil") + } + if want := "TW_OTPAAS_ID"; !strings.Contains(err.Error(), want) { + t.Fatalf("want err containing %q; got %q", want, err.Error()) + } + }) + + t.Run("otpaas missing app namespace", func(t *testing.T) { + cfg := fixtureConfigWithProvider(OTPProviderOTPaaS) + cfg.OTP.OTPaaS.Namespace = "" + + err := cfg.Validate() + + if err == nil { + t.Fatalf("want error; got nil") + } + if want := "TW_OTPAAS_NAMESPACE"; !strings.Contains(err.Error(), want) { + t.Fatalf("want err containing %q; got %q", want, err.Error()) + } + }) + + t.Run("otpaas missing secret", func(t *testing.T) { + cfg := fixtureConfigWithProvider(OTPProviderOTPaaS) + cfg.OTP.OTPaaS.Secret = "" + + err := cfg.Validate() + + if err == nil { + t.Fatalf("want error; got nil") + } + if want := "TW_OTPAAS_SECRET"; !strings.Contains(err.Error(), want) { + t.Fatalf("want err containing %q; got %q", want, err.Error()) + } + }) + + t.Run("otpaas zero timeout", func(t *testing.T) { + cfg := fixtureConfigWithProvider(OTPProviderOTPaaS) + cfg.OTP.OTPaaS.Timeout = 0 + + err := cfg.Validate() + + if err == nil { + t.Fatalf("want error; got nil") + } + if want := "TW_OTPAAS_TIMEOUT"; !strings.Contains(err.Error(), want) { + t.Fatalf("want err containing %q; got %q", want, err.Error()) + } + }) + + t.Run("otpaas non-positive timeout", func(t *testing.T) { + cfg := fixtureConfigWithProvider(OTPProviderOTPaaS) + cfg.OTP.OTPaaS.Timeout = -1 + + err := cfg.Validate() + + if err == nil { + t.Fatalf("want error; got nil") + } + if want := "TW_OTPAAS_TIMEOUT"; !strings.Contains(err.Error(), want) { + t.Fatalf("want err containing %q; got %q", want, err.Error()) + } + }) + + t.Run("valid mock config", func(t *testing.T) { + cfg := fixtureConfigWithProvider(OTPProviderMock) + + err := cfg.Validate() + + if err != nil { + t.Fatalf("want nil; got %v", err) + } + }) + + t.Run("mock missing allowed emails", func(t *testing.T) { + cfg := fixtureConfigWithProvider(OTPProviderMock) + cfg.OTP.Mock.AllowedEmails = nil + + err := cfg.Validate() + + if err == nil { + t.Fatalf("want error; got nil") + } + if want := "TW_MOCK_ALLOWED_EMAILS"; !strings.Contains(err.Error(), want) { + t.Fatalf("want err containing %q; got %q", want, err.Error()) + } + }) + + t.Run("multiple errors are joined", func(t *testing.T) { + cfg := defaultFixtureConfig() + cfg.Environment = "bad" + cfg.AllowedEmailDomains = nil + cfg.Server.Port = 0 + cfg.OTP.OTPaaS.Host = "" + + err := cfg.Validate() + if err == nil { + t.Fatal("want err; got nil") + } + + msg := err.Error() + for _, fragment := range []string{ + "TW_ENV", + "TW_ALLOWED_EMAIL_DOMAINS", + "TW_SERVER_PORT", + "TW_OTPAAS_HOST", + } { + if !strings.Contains(msg, fragment) { + t.Errorf("want err to contain %q; got %q", fragment, msg) + } + } + }) +}