Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
365 changes: 365 additions & 0 deletions server/internal/config/config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,365 @@
package config

import (
"strings"
"testing"
)

func defaultFixtureConfig() *Config {
return fixtureConfigWithProvider(OTPProviderOTPaaS)
}

func fixtureConfigWithProvider(provider Provider) *Config {
cfg := Default()
cfg.OTP.Provider = provider

switch provider {
case OTPProviderOTPaaS:
cfg.OTP.OTPaaS.ID = "app-id"
cfg.OTP.OTPaaS.Namespace = "app-namespace"
cfg.OTP.OTPaaS.Secret = "secret"
case OTPProviderMock:
cfg.OTP.Mock.AllowedEmails = []string{"whitelisted@example.com"}
}

return cfg
}

func TestConfig_Validate(t *testing.T) {
t.Run("valid environment", func(t *testing.T) {
cfg := defaultFixtureConfig()
cfg.Environment = EnvironmentProduction
cfg.BundleDirectory = t.TempDir()

err := cfg.Validate()

if err != nil {
t.Fatalf("want nil; got %v", err)
}
})

t.Run("invalid environment", func(t *testing.T) {
cfg := defaultFixtureConfig()
cfg.Environment = "staging"

err := cfg.Validate()

if err == nil {
t.Fatalf("want error; got nil")
}
if want := "TW_ENV"; !strings.Contains(err.Error(), want) {
t.Fatalf("want err containing %q; got %q", want, err.Error())
}
})

t.Run("invalid provider", func(t *testing.T) {
cfg := defaultFixtureConfig()
cfg.OTP.Provider = "sms"

err := cfg.Validate()

if err == nil {
t.Fatalf("want error; got nil")
}
if want := "TW_OTP_PROVIDER"; !strings.Contains(err.Error(), want) {
t.Fatalf("want err containing %q; got %q", want, err.Error())
}
})

t.Run("empty allowed email domains", func(t *testing.T) {
cfg := defaultFixtureConfig()
cfg.AllowedEmailDomains = nil

err := cfg.Validate()

if err == nil {
t.Fatalf("want error; got nil")
}
if want := "TW_ALLOWED_EMAIL_DOMAINS"; !strings.Contains(err.Error(), want) {
t.Fatalf("want err containing %q; got %q", want, err.Error())
}
})

t.Run("server port zero", func(t *testing.T) {
cfg := defaultFixtureConfig()
cfg.Server.Port = 0

err := cfg.Validate()

if err == nil {
t.Fatalf("want error; got nil")
}
if want := "TW_SERVER_PORT"; !strings.Contains(err.Error(), want) {
t.Fatalf("want err containing %q; got %q", want, err.Error())
}
})

t.Run("server port exceeds max value", func(t *testing.T) {
cfg := defaultFixtureConfig()
cfg.Server.Port = 65536

err := cfg.Validate()

if err == nil {
t.Fatalf("want error; got nil")
}
if want := "TW_SERVER_PORT"; !strings.Contains(err.Error(), want) {
t.Fatalf("want err containing %q; got %q", want, err.Error())
}
})

t.Run("server zero read header timeout", func(t *testing.T) {
cfg := defaultFixtureConfig()
cfg.Server.ReadHeaderTimeout = 0

err := cfg.Validate()

if err == nil {
t.Fatalf("want error; got nil")
}
if want := "TW_SERVER_READ_HEADER_TIMEOUT"; !strings.Contains(err.Error(), want) {
t.Fatalf("want err containing %q; got %q", want, err.Error())
}
})

t.Run("server non-positive read header timeout", func(t *testing.T) {
cfg := defaultFixtureConfig()
cfg.Server.ReadHeaderTimeout = -1

err := cfg.Validate()

if err == nil {
t.Fatalf("want error; got nil")
}
if want := "TW_SERVER_READ_HEADER_TIMEOUT"; !strings.Contains(err.Error(), want) {
t.Fatalf("want err containing %q; got %q", want, err.Error())
}
})

t.Run("server zero read timeout", func(t *testing.T) {
cfg := defaultFixtureConfig()
cfg.Server.ReadTimeout = 0

err := cfg.Validate()

if err == nil {
t.Fatalf("want error; got nil")
}
if want := "TW_SERVER_READ_TIMEOUT"; !strings.Contains(err.Error(), want) {
t.Fatalf("want err containing %q; got %q", want, err.Error())
}
})

t.Run("server non-positive read timeout", func(t *testing.T) {
cfg := defaultFixtureConfig()
cfg.Server.ReadTimeout = -1

err := cfg.Validate()

if err == nil {
t.Fatalf("want error; got nil")
}
if want := "TW_SERVER_READ_TIMEOUT"; !strings.Contains(err.Error(), want) {
t.Fatalf("want err containing %q; got %q", want, err.Error())
}
})

t.Run("server zero write timeout", func(t *testing.T) {
cfg := defaultFixtureConfig()
cfg.Server.WriteTimeout = 0

err := cfg.Validate()

if err == nil {
t.Fatalf("want error; got nil")
}
if want := "TW_SERVER_WRITE_TIMEOUT"; !strings.Contains(err.Error(), want) {
t.Fatalf("want err containing %q; got %q", want, err.Error())
}
})

t.Run("server non-positive write timeout", func(t *testing.T) {
cfg := defaultFixtureConfig()
cfg.Server.WriteTimeout = -1

err := cfg.Validate()

if err == nil {
t.Fatalf("want error; got nil")
}
if want := "TW_SERVER_WRITE_TIMEOUT"; !strings.Contains(err.Error(), want) {
t.Fatalf("want err containing %q; got %q", want, err.Error())
}
})

t.Run("server zero idle timeout", func(t *testing.T) {
cfg := defaultFixtureConfig()
cfg.Server.IdleTimeout = 0

err := cfg.Validate()

if err == nil {
t.Fatalf("want error; got nil")
}
if want := "TW_SERVER_IDLE_TIMEOUT"; !strings.Contains(err.Error(), want) {
t.Fatalf("want err containing %q; got %q", want, err.Error())
}
})

t.Run("server non-positive idle timeout", func(t *testing.T) {
cfg := defaultFixtureConfig()
cfg.Server.IdleTimeout = -1

err := cfg.Validate()

if err == nil {
t.Fatalf("want error; got nil")
}
if want := "TW_SERVER_IDLE_TIMEOUT"; !strings.Contains(err.Error(), want) {
t.Fatalf("want err containing %q; got %q", want, err.Error())
}
})

t.Run("valid otpaas config", func(t *testing.T) {
cfg := fixtureConfigWithProvider(OTPProviderOTPaaS)

err := cfg.Validate()

if err != nil {
t.Fatalf("want nil; got %v", err)
}
})

t.Run("otpaas missing host", func(t *testing.T) {
cfg := fixtureConfigWithProvider(OTPProviderOTPaaS)
cfg.OTP.OTPaaS.Host = ""

err := cfg.Validate()

if err == nil {
t.Fatalf("want error; got nil")
}
if want := "TW_OTPAAS_HOST"; !strings.Contains(err.Error(), want) {
t.Fatalf("want err containing %q; got %q", want, err.Error())
}
})

t.Run("otpaas missing app ID", func(t *testing.T) {
cfg := fixtureConfigWithProvider(OTPProviderOTPaaS)
cfg.OTP.OTPaaS.ID = ""

err := cfg.Validate()

if err == nil {
t.Fatalf("want error; got nil")
}
if want := "TW_OTPAAS_ID"; !strings.Contains(err.Error(), want) {
t.Fatalf("want err containing %q; got %q", want, err.Error())
}
})

t.Run("otpaas missing app namespace", func(t *testing.T) {
cfg := fixtureConfigWithProvider(OTPProviderOTPaaS)
cfg.OTP.OTPaaS.Namespace = ""

err := cfg.Validate()

if err == nil {
t.Fatalf("want error; got nil")
}
if want := "TW_OTPAAS_NAMESPACE"; !strings.Contains(err.Error(), want) {
t.Fatalf("want err containing %q; got %q", want, err.Error())
}
})

t.Run("otpaas missing secret", func(t *testing.T) {
cfg := fixtureConfigWithProvider(OTPProviderOTPaaS)
cfg.OTP.OTPaaS.Secret = ""

err := cfg.Validate()

if err == nil {
t.Fatalf("want error; got nil")
}
if want := "TW_OTPAAS_SECRET"; !strings.Contains(err.Error(), want) {
t.Fatalf("want err containing %q; got %q", want, err.Error())
}
})

t.Run("otpaas zero timeout", func(t *testing.T) {
cfg := fixtureConfigWithProvider(OTPProviderOTPaaS)
cfg.OTP.OTPaaS.Timeout = 0

err := cfg.Validate()

if err == nil {
t.Fatalf("want error; got nil")
}
if want := "TW_OTPAAS_TIMEOUT"; !strings.Contains(err.Error(), want) {
t.Fatalf("want err containing %q; got %q", want, err.Error())
}
})

t.Run("otpaas non-positive timeout", func(t *testing.T) {
cfg := fixtureConfigWithProvider(OTPProviderOTPaaS)
cfg.OTP.OTPaaS.Timeout = -1

err := cfg.Validate()

if err == nil {
t.Fatalf("want error; got nil")
}
if want := "TW_OTPAAS_TIMEOUT"; !strings.Contains(err.Error(), want) {
t.Fatalf("want err containing %q; got %q", want, err.Error())
}
})

t.Run("valid mock config", func(t *testing.T) {
cfg := fixtureConfigWithProvider(OTPProviderMock)

err := cfg.Validate()

if err != nil {
t.Fatalf("want nil; got %v", err)
}
})

t.Run("mock missing allowed emails", func(t *testing.T) {
cfg := fixtureConfigWithProvider(OTPProviderMock)
cfg.OTP.Mock.AllowedEmails = nil

err := cfg.Validate()

if err == nil {
t.Fatalf("want error; got nil")
}
if want := "TW_MOCK_ALLOWED_EMAILS"; !strings.Contains(err.Error(), want) {
t.Fatalf("want err containing %q; got %q", want, err.Error())
}
})

t.Run("multiple errors are joined", func(t *testing.T) {
cfg := defaultFixtureConfig()
cfg.Environment = "bad"
cfg.AllowedEmailDomains = nil
cfg.Server.Port = 0
cfg.OTP.OTPaaS.Host = ""

err := cfg.Validate()
if err == nil {
t.Fatal("want err; got nil")
}

msg := err.Error()
for _, fragment := range []string{
"TW_ENV",
"TW_ALLOWED_EMAIL_DOMAINS",
"TW_SERVER_PORT",
"TW_OTPAAS_HOST",
} {
if !strings.Contains(msg, fragment) {
t.Errorf("want err to contain %q; got %q", fragment, msg)
}
}
})
}