diff --git a/go.mod b/go.mod index 9b4a6180..3f8b0413 100644 --- a/go.mod +++ b/go.mod @@ -2,22 +2,34 @@ module github.com/codeshelldev/secured-signal-api go 1.25.5 -require github.com/codeshelldev/gotl v0.0.12 - -require go.uber.org/zap v1.27.1 // indirect +require ( + github.com/codeshelldev/gotl/pkg/configutils v0.0.4 + github.com/codeshelldev/gotl/pkg/docker v0.0.2 + github.com/codeshelldev/gotl/pkg/jsonutils v0.0.2 + github.com/codeshelldev/gotl/pkg/logger v0.0.3 + github.com/codeshelldev/gotl/pkg/pretty v0.0.8 + github.com/codeshelldev/gotl/pkg/query v0.0.3 + github.com/codeshelldev/gotl/pkg/request v0.0.3 + github.com/codeshelldev/gotl/pkg/stringutils v0.0.3 + github.com/codeshelldev/gotl/pkg/templating v0.0.3 + github.com/knadh/koanf/parsers/yaml v1.1.0 + go.uber.org/zap v1.27.1 + golang.org/x/time v0.14.0 +) require ( + github.com/clipperhouse/uax29/v2 v2.2.0 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/knadh/koanf/maps v0.1.2 // indirect - github.com/knadh/koanf/parsers/yaml v1.1.0 github.com/knadh/koanf/providers/confmap v1.0.0 // indirect github.com/knadh/koanf/providers/env/v2 v2.0.0 // indirect github.com/knadh/koanf/providers/file v1.2.1 // indirect github.com/knadh/koanf/v2 v2.3.0 // indirect + github.com/mattn/go-runewidth v0.0.19 // indirect github.com/mitchellh/copystructure v1.2.0 // indirect github.com/mitchellh/reflectwalk v1.0.2 // indirect go.uber.org/multierr v1.11.0 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect - golang.org/x/sys v0.39.0 // indirect + golang.org/x/sys v0.40.0 // indirect ) diff --git a/go.sum b/go.sum index 645685e1..2cea6122 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,25 @@ -github.com/codeshelldev/gotl v0.0.9 h1:cdLA6XzPt+f4RIW24Yx3dqBbRAq5JO0obzuwhaOgsEo= -github.com/codeshelldev/gotl v0.0.9/go.mod h1:rDkJma6eQSRfCr7ieX9/esn3/uAWNzjHfpjlr9j6FFk= -github.com/codeshelldev/gotl v0.0.12 h1:VM3W6hiEfPgK+cCLT70S004tYAdhQWXD72FtFqCF+2Q= -github.com/codeshelldev/gotl v0.0.12/go.mod h1:rDkJma6eQSRfCr7ieX9/esn3/uAWNzjHfpjlr9j6FFk= +github.com/clipperhouse/uax29/v2 v2.2.0 h1:ChwIKnQN3kcZteTXMgb1wztSgaU+ZemkgWdohwgs8tY= +github.com/clipperhouse/uax29/v2 v2.2.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM= +github.com/codeshelldev/gotl/pkg/configutils v0.0.4 h1:AHCCP+8FzFKP8lMtytWo98x0583bCPI7wykbw6z9pTs= +github.com/codeshelldev/gotl/pkg/configutils v0.0.4/go.mod h1:WoWMBB8+84ePRnI2m+kbq1Rw8F/9iCWLHkVBsun3Qjc= +github.com/codeshelldev/gotl/pkg/docker v0.0.2 h1:kpseReocEBoSzWe/tOhUrIrOYeAR/inw3EF2/d+N078= +github.com/codeshelldev/gotl/pkg/docker v0.0.2/go.mod h1:odNnlRw4aO1n2hSkDZIaiuSXIoFoVeatmXtF64Yd33U= +github.com/codeshelldev/gotl/pkg/jsonutils v0.0.2 h1:ERsjkaWVrsyUZoEunCEeNYDXhuaIvoSetB8e/zI4Tqo= +github.com/codeshelldev/gotl/pkg/jsonutils v0.0.2/go.mod h1:oxgKaAoMu6iYVHfgR7AhkK22xbYx4K0KCkyVEfYVoWs= +github.com/codeshelldev/gotl/pkg/logger v0.0.3 h1:M99fsH7SiIFS4jNRCNtu3BJNYdcPD+LbqJ7l5aBQeJ8= +github.com/codeshelldev/gotl/pkg/logger v0.0.3/go.mod h1:pL/I7KYxbGHhyedallZlCkBvoalv9gAWNEYVXbF9BoM= +github.com/codeshelldev/gotl/pkg/pretty v0.0.6 h1:b+1k4Sm6Do7TqUlOFQ5YjybyHJMXYs72GYYBJpSL5JQ= +github.com/codeshelldev/gotl/pkg/pretty v0.0.6/go.mod h1:2Gk6UBrtkIME2RCSNytS/RJ5lHXYL/MDx0rYRpknobM= +github.com/codeshelldev/gotl/pkg/pretty v0.0.8 h1:buLobwNqZRlYGnfyFLi7A7z2m7362Wm9k5Y+Tv0tMsI= +github.com/codeshelldev/gotl/pkg/pretty v0.0.8/go.mod h1:2Gk6UBrtkIME2RCSNytS/RJ5lHXYL/MDx0rYRpknobM= +github.com/codeshelldev/gotl/pkg/query v0.0.3 h1:Zy8k0R5HcJS00OMPRHybgFEiwMg7ceLyv6bA0G7NOfs= +github.com/codeshelldev/gotl/pkg/query v0.0.3/go.mod h1:kKaPOKXluIid3qeS7xzrmfq3NxIa8/PhKYHM6GRbwJw= +github.com/codeshelldev/gotl/pkg/request v0.0.3 h1:maRPHu366NARow8/m1Q8Cw1EU1Uy0pDIn1vlAsOatKM= +github.com/codeshelldev/gotl/pkg/request v0.0.3/go.mod h1:vCXIZ2n/XxvEVInBQv9eIh0kQ2353V+WymL8kZ9yrOU= +github.com/codeshelldev/gotl/pkg/stringutils v0.0.3 h1:7k/HMnX7me8Kchm41I/M6dp3wXI0XORI3oyS87O0Viw= +github.com/codeshelldev/gotl/pkg/stringutils v0.0.3/go.mod h1:/dWlzYoTj23LmpFs+Bpal4tfUDbOVeApIgkLv9gTgUE= +github.com/codeshelldev/gotl/pkg/templating v0.0.3 h1:PAz6VN8yGBuZIdR/sDM+TmW1OFykl+I7/Zwa07uMgYA= +github.com/codeshelldev/gotl/pkg/templating v0.0.3/go.mod h1:D+wxgsPSiq9HShEzv1mhYAjGJyasWgPoIu+nRk4TPqY= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= @@ -24,6 +42,8 @@ github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= +github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw= github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s= github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ= @@ -40,8 +60,10 @@ go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc= go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= -golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= -golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= +golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internals/config/loader.go b/internals/config/loader.go index dbdfaab9..ab81a2a9 100644 --- a/internals/config/loader.go +++ b/internals/config/loader.go @@ -7,7 +7,7 @@ import ( "strings" "github.com/codeshelldev/gotl/pkg/configutils" - log "github.com/codeshelldev/gotl/pkg/logger" + "github.com/codeshelldev/gotl/pkg/logger" "github.com/codeshelldev/gotl/pkg/stringutils" "github.com/codeshelldev/secured-signal-api/internals/config/structure" @@ -63,13 +63,13 @@ func Load() { InitTokens() - log.Info("Finished Loading Configuration") + logger.Info("Finished Loading Configuration") } func Log() { - log.Dev("Loaded Config:", mainConf.Layer.Get("")) - log.Dev("Loaded Token Configs:", tokenConf.Layer.Get("")) - log.Dev("Parsed Configs: ", ENV) + logger.Dev("Loaded Config:", mainConf.Layer.Get("")) + logger.Dev("Loaded Token Configs:", tokenConf.Layer.Get("")) + logger.Dev("Parsed Configs: ", ENV) } func Clear() { @@ -102,7 +102,7 @@ func Normalize(id string, config *configutils.Config, path string, structure any old, ok := data.(map[string]any) if !ok { - log.Warn("Could not load `"+path+"`") + logger.Warn("Could not load `"+path+"`") return } @@ -111,7 +111,10 @@ func Normalize(id string, config *configutils.Config, path string, structure any tmpConf.Load(old, "") // Apply transforms to the new config - tmpConf.ApplyTransformFuncs(id, structure, "", transformFuncs) + tmpConf.ApplyTransformFuncs(id, structure, "", configutils.TransformOptions{ + Transforms: transformFuncs, + OnUse: onUseFuncs, + }) // Lowercase actual config LowercaseKeys(config) @@ -124,7 +127,7 @@ func Normalize(id string, config *configutils.Config, path string, structure any func InitReload() { reload := func(path string) { - log.Debug(path, " changed, reloading...") + logger.Debug(path, " changed, reloading...") Load() Log() } @@ -145,16 +148,16 @@ func InitConfig() { } func LoadDefaults() { - log.Debug("Loading defaults ", ENV.DEFAULTS_PATH) + logger.Debug("Loading defaults ", ENV.DEFAULTS_PATH) _, err := defaultsConf.LoadFile(ENV.DEFAULTS_PATH, yaml.Parser()) if err != nil { - log.Warn("Could not Load Defaults", ENV.DEFAULTS_PATH) + logger.Warn("Could not Load Defaults", ENV.DEFAULTS_PATH) } } func LoadConfig() { - log.Debug("Loading Config ", ENV.CONFIG_PATH) + logger.Debug("Loading Config ", ENV.CONFIG_PATH) _, err := userConf.LoadFile(ENV.CONFIG_PATH, yaml.Parser()) if err != nil { @@ -166,7 +169,7 @@ func LoadConfig() { return } - log.Error("Could not Load Config ", ENV.CONFIG_PATH, ": ", err.Error()) + logger.Error("Could not Load Config ", ENV.CONFIG_PATH, ": ", err.Error()) } } diff --git a/internals/config/parser.go b/internals/config/parser.go index 0222cdde..c63886d6 100644 --- a/internals/config/parser.go +++ b/internals/config/parser.go @@ -1,7 +1,11 @@ package config import ( + "fmt" "strings" + + "github.com/codeshelldev/gotl/pkg/configutils" + "github.com/codeshelldev/gotl/pkg/pretty" ) var transformFuncs = map[string]func(string, any) (string, any) { @@ -21,4 +25,54 @@ func lowercaseTransform(key string, value any) (string, any) { func uppercaseTransform(key string, value any) (string, any) { return strings.ToUpper(key), value +} + +var onUseFuncs = map[string]func(source string, target configutils.TransformTarget) { + "deprecated": func(source string, target configutils.TransformTarget) { + box := pretty.NewAutoBox() + box.MinWidth = 50 + box.PaddingX = 2 + box.PaddingY = 1 + + box.Border.Style = pretty.BorderStyle{ + Color: pretty.Basic(pretty.Yellow), + } + + box.AddBlock(pretty.Block{ + Align: pretty.AlignCenter, + Style: pretty.Style{}, + Segments: []pretty.Segment{ + pretty.TextBlockSegment{ + Text: "🚨 Deprecation 🚨", + Style: pretty.Style{ + Bold: true, + Foreground: pretty.Basic(pretty.Yellow), + }, + }, + pretty.InlineSegment{}, + pretty.TextBlockSegment{ + Text: "Please refrain from using", + }, + pretty.InlineSegment{}, + pretty.TextBlockSegment{ + Text: "`" + source + "`", + Style: pretty.Style{ + Italic: true, + Bold: true, + Background: pretty.Basic(pretty.Red), + }, + }, + pretty.InlineSegment{}, + pretty.InlineSegment{ + Items: []pretty.Inline{ + pretty.Span{ + Text: "as it has been marked as deprecated", + }, + }, + }, + }, + }) + + fmt.Println(box.Render()) + }, } \ No newline at end of file diff --git a/internals/config/structure/structure.go b/internals/config/structure/structure.go index ba2ce910..2c2c127e 100644 --- a/internals/config/structure/structure.go +++ b/internals/config/structure/structure.go @@ -11,29 +11,29 @@ type ENV struct { } type CONFIG struct { + NAME string `koanf:"name"` SERVICE SERVICE `koanf:"service"` API API `koanf:"api"` - //TODO: deprecate overrides for tkconfigs - SETTINGS SETTINGS `koanf:"settings" token>aliases:"overrides"` + SETTINGS SETTINGS `koanf:"settings"` } type SERVICE struct { + HOSTNAMES []string `koanf:"hostnames" env>aliases:".hostnames"` PORT string `koanf:"port" env>aliases:".port"` LOG_LEVEL string `koanf:"loglevel" env>aliases:".loglevel"` } type API struct { URL string `koanf:"url" env>aliases:".apiurl"` - //TODO: deprecate .token for tkconfigs - TOKENS []string `koanf:"tokens" env>aliases:".apitokens,.apitoken" token>aliases:".tokens,.token" aliases:"token"` + TOKENS []string `koanf:"tokens" env>aliases:".apitokens,.apitoken" aliases:"token"` } type SETTINGS struct { - ACCESS ACCESS_SETTINGS `koanf:"access"` - MESSAGE MESSAGE_SETTINGS `koanf:"message"` + ACCESS ACCESS `koanf:"access"` + MESSAGE MESSAGE `koanf:"message"` } -type MESSAGE_SETTINGS struct { +type MESSAGE struct { VARIABLES map[string]any `koanf:"variables" childtransform:"upper"` FIELD_MAPPINGS map[string][]FieldMapping `koanf:"fieldmappings" childtransform:"default"` TEMPLATE string `koanf:"template"` @@ -44,12 +44,21 @@ type FieldMapping struct { Score int `koanf:"score"` } -type ACCESS_SETTINGS struct { +type ACCESS struct { ENDPOINTS []string `koanf:"endpoints"` - FIELD_POLICIES map[string]FieldPolicy `koanf:"fieldpolicies" childtransform:"default"` + FIELD_POLICIES map[string][]FieldPolicy `koanf:"fieldpolicies" childtransform:"default"` + RATE_LIMITING RateLimiting `koanf:"ratelimiting"` + IP_FILTER []string `koanf:"ipfilter"` + TRUSTED_IPS []string `koanf:"trustedips"` + TRUSTED_PROXIES []string `koanf:"trustedproxies"` } type FieldPolicy struct { Value any `koanf:"value"` Action string `koanf:"action"` +} + +type RateLimiting struct { + Limit int `koanf:"limit"` + Period string `koanf:"period"` } \ No newline at end of file diff --git a/internals/config/tokens.go b/internals/config/tokens.go index 21c5d242..6fb44532 100644 --- a/internals/config/tokens.go +++ b/internals/config/tokens.go @@ -1,21 +1,24 @@ package config import ( + "path/filepath" + "reflect" "strconv" + "strings" "github.com/codeshelldev/gotl/pkg/configutils" - log "github.com/codeshelldev/gotl/pkg/logger" + "github.com/codeshelldev/gotl/pkg/logger" "github.com/codeshelldev/secured-signal-api/internals/config/structure" "github.com/knadh/koanf/parsers/yaml" ) func LoadTokens() { - log.Debug("Loading Configs in ", ENV.TOKENS_DIR) + logger.Debug("Loading Configs in ", ENV.TOKENS_DIR) - err := tokenConf.LoadDir("tokenconfigs", ENV.TOKENS_DIR, ".yml", yaml.Parser(), func(c *configutils.Config, s string) {}) + err := tokenConf.LoadDir("tokenconfigs", ENV.TOKENS_DIR, ".yml", yaml.Parser(), setTokenConfigName) if err != nil { - log.Error("Could not Load Configs in ", ENV.TOKENS_DIR, ": ", err.Error()) + logger.Error("Could not Load Configs in ", ENV.TOKENS_DIR, ": ", err.Error()) } tokenConf.TemplateConfig() @@ -57,9 +60,9 @@ func InitTokens() { } if len(apiTokens) <= 0 { - log.Warn("No API Tokens provided this is NOT recommended") + logger.Warn("No API Tokens provided this is NOT recommended") - log.Info("Disabling Security Features due to incomplete Congfiguration") + logger.Info("Disabling Security Features due to incomplete Congfiguration") ENV.INSECURE = true @@ -69,7 +72,7 @@ func InitTokens() { } if len(apiTokens) > 0 { - log.Debug("Registered " + strconv.Itoa(len(apiTokens)) + " Tokens") + logger.Debug("Registered " + strconv.Itoa(len(apiTokens)) + " Tokens") } } @@ -84,3 +87,39 @@ func parseTokenConfigs(configArray []structure.CONFIG) map[string]structure.CONF return configs } + +func getSchemeTagByPointer(config any, tag string, fieldPointer any) string { + v := reflect.ValueOf(config) + if v.Kind() == reflect.Pointer { + v = v.Elem() + } + + fieldValue := reflect.ValueOf(fieldPointer).Elem() + + for i := 0; i < v.NumField(); i++ { + if v.Field(i).Addr().Interface() == fieldValue.Addr().Interface() { + field := v.Type().Field(i) + + return field.Tag.Get(tag) + } + } + + return "" +} + +func setTokenConfigName(config *configutils.Config, p string) { + schema := structure.CONFIG{ + NAME: "", + } + + nameField := getSchemeTagByPointer(&schema, "koanf", &schema.NAME) + + filename := filepath.Base(p) + filenameWithoutExt := strings.TrimSuffix(filename, filepath.Ext(filename)) + + name := config.Layer.String(nameField) + + if strings.TrimSpace(name) == "" { + config.Layer.Set(nameField, filenameWithoutExt) + } +} \ No newline at end of file diff --git a/internals/proxy/middlewares/auth.go b/internals/proxy/middlewares/auth.go index 9d0f2725..91afc979 100644 --- a/internals/proxy/middlewares/auth.go +++ b/internals/proxy/middlewares/auth.go @@ -1,14 +1,16 @@ package middlewares import ( - "context" "encoding/base64" + "errors" "maps" "net/http" + "net/url" "slices" "strings" - log "github.com/codeshelldev/gotl/pkg/logger" + "github.com/codeshelldev/gotl/pkg/logger" + "github.com/codeshelldev/gotl/pkg/request" "github.com/codeshelldev/secured-signal-api/internals/config" ) @@ -17,101 +19,268 @@ var Auth Middleware = Middleware{ Use: authHandler, } -func authHandler(next http.Handler) http.Handler { - tokenKeys := maps.Keys(config.ENV.CONFIGS) - tokens := slices.Collect(tokenKeys) +const tokenKey contextKey = "token" +const isAuthKey contextKey = "isAuthenticated" - if tokens == nil { - tokens = []string{} - } +func authHandler(next http.Handler) http.Handler { + var authChain = NewAuthChain(). + Use(BearerAuth). + Use(BasicAuth). + Use(BodyAuth). + Use(QueryAuth). + Use(PathAuth) return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - if len(tokens) <= 0 { + tokenKeys := maps.Keys(config.ENV.CONFIGS) + tokens := slices.Collect(tokenKeys) + + if tokens == nil { + tokens = []string{} + } + + if config.ENV.INSECURE || len(tokens) <= 0 { next.ServeHTTP(w, req) return } - authHeader := req.Header.Get("Authorization") + token, _ := authChain.Eval(w, req, tokens) - authQuery := req.URL.Query().Get("@authorization") + if token == "" { + onUnauthorized(w) - var authType authType = None + req = setContext(req, isAuthKey, false) + } else { + req = setContext(req, isAuthKey, true) + req = setContext(req, tokenKey, token) + } + + next.ServeHTTP(w, req) + }) +} - var authToken string +var InternalAuthRequirement Middleware = Middleware{ + Name: "_Auth_Requirement", + Use: authRequirementHandler, +} - success := false +func authRequirementHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + isAuthenticated := getContext[bool](req, isAuthKey) - if authHeader != "" { - authBody := strings.Split(authHeader, " ") + if !isAuthenticated { + return + } - authType = getAuthType(authBody[0]) - authToken = authBody[1] + next.ServeHTTP(w, req) + }) +} - switch authType { - case Bearer: - if isValidToken(tokens, authToken) { - success = true - } - case Basic: - basicAuthBody, err := base64.StdEncoding.DecodeString(authToken) +type AuthMethod struct { + Name string + Authenticate func(w http.ResponseWriter, req *http.Request, tokens []string) (string, error) +} - if err != nil { - log.Error("Could not decode Basic Auth Payload: ", err.Error()) - } +var BearerAuth = AuthMethod{ + Name: "Bearer", + Authenticate: func(w http.ResponseWriter, req *http.Request, tokens []string) (string, error) { + header := req.Header.Get("Authorization") - basicAuth := string(basicAuthBody) - basicAuthParts := strings.Split(basicAuth, ":") + headerParts := strings.SplitN(header, " ", 2) - user := "api" - authToken = basicAuthParts[1] + if len(headerParts) != 2 { + return "", nil + } - if basicAuthParts[0] == user && isValidToken(tokens, authToken) { - success = true - } + if strings.ToLower(headerParts[0]) == "bearer" { + if isValidToken(tokens, headerParts[1]) { + return headerParts[1], nil } - } else if authQuery != "" { - authType = Query + return "", errors.New("invalid Bearer token") + } - authToken = strings.TrimSpace(authQuery) + return "", nil + }, +} - if isValidToken(tokens, authToken) { - success = true +var BasicAuth = AuthMethod{ + Name: "Basic", + Authenticate: func(w http.ResponseWriter, req *http.Request, tokens []string) (string, error) { + header := req.Header.Get("Authorization") - modifiedQuery := req.URL.Query() + if strings.TrimSpace(header) == "" { + return "", nil + } - modifiedQuery.Del("@authorization") + headerParts := strings.SplitN(header, " ", 2) - req.URL.RawQuery = modifiedQuery.Encode() + if len(headerParts) != 2 { + return "", nil + } + + if strings.ToLower(headerParts[0]) == "basic" { + base64Bytes, err := base64.StdEncoding.DecodeString(headerParts[1]) + + if err != nil { + logger.Error("Could not decode Basic auth payload: ", err.Error()) + return "", errors.New("invalid base64 in Basic auth") + } + + parts := strings.SplitN(string(base64Bytes), ":", 2) + + if len(parts) != 2 { + return "", errors.New("Basic auth must be user:password") + } + + user, password := parts[0], parts[1] + + if strings.ToLower(user) == "api" && isValidToken(tokens, password) { + return password, nil } + + return "", errors.New("invalid user:password") } - if !success { - w.Header().Set("WWW-Authenticate", "Basic realm=\"Login Required\", Bearer realm=\"Access Token Required\"") + return "", nil + }, +} - log.Warn("User failed ", string(authType), " Auth") - http.Error(w, "Unauthorized", http.StatusUnauthorized) - return +var BodyAuth = AuthMethod{ + Name: "Body", + Authenticate: func(w http.ResponseWriter, req *http.Request, tokens []string) (string, error) { + const authField = "auth" + + body, err := request.GetReqBody(req) + + if err != nil { + return "", nil } - ctx := context.WithValue(req.Context(), tokenKey, authToken) - req = req.WithContext(ctx) + body.Write(req) - next.ServeHTTP(w, req) - }) + if body.Empty { + return "", nil + } + + value, exists := body.Data[authField] + + if !exists { + return "", nil + } + + auth, ok := value.(string) + + if !ok { + return "", nil + } + + if isValidToken(tokens, auth) { + delete(body.Data, authField) + + body.Write(req) + + return auth, nil + } + + return "", errors.New("invalid Body token") + }, } -func getAuthType(str string) authType { - switch str { - case "Bearer": - return Bearer - case "Basic": - return Basic - default: - return None - } +var QueryAuth = AuthMethod{ + Name: "Query", + Authenticate: func(w http.ResponseWriter, req *http.Request, tokens []string) (string, error) { + const authQuery = "auth" + + auth := req.URL.Query().Get("@" + authQuery) + + if strings.TrimSpace(auth) == "" { + return "", nil + } + + if isValidToken(tokens, auth) { + query := req.URL.Query() + + query.Del("@" + authQuery) + + req.URL.RawQuery = query.Encode() + + return auth, nil + } + + return "", errors.New("invalid Query token") + }, +} + +var PathAuth = AuthMethod{ + Name: "Path", + Authenticate: func(w http.ResponseWriter, req *http.Request, tokens []string) (string, error) { + parts := strings.Split(req.URL.Path, "/") + + if len(parts) == 0 { + return "", nil + } + + unescaped, err := url.PathUnescape(parts[1]) + + if err != nil { + return "", nil + } + + auth, exists := strings.CutPrefix(unescaped, "auth=") + + if !exists { + return "", nil + } + + if isValidToken(tokens, auth) { + return auth, nil + } + + return "", errors.New("invalid Path token") + }, +} + +func onUnauthorized(w http.ResponseWriter) { + w.Header().Set("WWW-Authenticate", "Basic realm=\"Login Required\", Bearer realm=\"Access Token Required\"") + + http.Error(w, "Unauthorized", http.StatusUnauthorized) } func isValidToken(tokens []string, match string) bool { return slices.Contains(tokens, match) } + +type AuthChain struct { + methods []AuthMethod +} + +func NewAuthChain() *AuthChain { + return &AuthChain{} +} + +func (chain *AuthChain) Use(method AuthMethod) *AuthChain { + chain.methods = append(chain.methods, method) + + logger.Debug("Registered ", method.Name, " auth") + + return chain +} + +func (chain *AuthChain) Eval(w http.ResponseWriter, req *http.Request, tokens []string) (string, error) { + var err error + var token string + + for _, method := range chain.methods { + token, err = method.Authenticate(w, req, tokens) + + if err != nil { + logger.Warn("Client failed ", method.Name, " auth: ", err.Error()) + } + + if token != "" { + return token, nil + } + } + + return "", err +} \ No newline at end of file diff --git a/internals/proxy/middlewares/clientip.go b/internals/proxy/middlewares/clientip.go new file mode 100644 index 00000000..11126c2a --- /dev/null +++ b/internals/proxy/middlewares/clientip.go @@ -0,0 +1,40 @@ +package middlewares + +import ( + "net" + "net/http" +) + +var InternalClientIP Middleware = Middleware{ + Name: "_Client_IP", + Use: clientIPHandler, +} + +var trustedClientKey contextKey = "isClientTrusted" + +func clientIPHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + logger := getLogger(req) + + conf := getConfigByReq(req) + + rawTrustedIPs := conf.SETTINGS.ACCESS.TRUSTED_IPS + + if rawTrustedIPs == nil { + rawTrustedIPs = getConfig("").SETTINGS.ACCESS.TRUSTED_IPS + } + + ip := getContext[net.IP](req, clientIPKey) + + trustedIPs := parseIPsAndIPNets(rawTrustedIPs) + trusted := isIPInList(ip, trustedIPs) + + if trusted { + logger.Dev("Connection from trusted Client: ", ip.String()) + } + + req = setContext(req, trustedClientKey, trusted) + + next.ServeHTTP(w, req) + }) +} \ No newline at end of file diff --git a/internals/proxy/middlewares/common.go b/internals/proxy/middlewares/common.go index 6984b471..cd497e20 100644 --- a/internals/proxy/middlewares/common.go +++ b/internals/proxy/middlewares/common.go @@ -1,8 +1,10 @@ package middlewares import ( + "context" "net/http" + "github.com/codeshelldev/gotl/pkg/logger" "github.com/codeshelldev/secured-signal-api/internals/config" "github.com/codeshelldev/secured-signal-api/internals/config/structure" ) @@ -11,23 +13,34 @@ type Context struct { Next http.Handler } -type authType string +type contextKey string -const ( - Bearer authType = "Bearer" - Basic authType = "Basic" - Query authType = "Query" - None authType = "None" -) +func setContext(req *http.Request, key, value any) *http.Request { + ctx := context.WithValue(req.Context(), key, value) + return req.WithContext(ctx) +} -type contextKey string +func getContext[T any](req *http.Request, key any) T { + value, ok := req.Context().Value(key).(T) -const tokenKey contextKey = "token" + if !ok { + var zero T + return zero + } -func getConfigByReq(req *http.Request) *structure.CONFIG { - token := req.Context().Value(tokenKey).(string) + return value +} + +func getLogger(req *http.Request) *logger.Logger { + return getContext[*logger.Logger](req, loggerKey) +} - return getConfig(token) +func getToken(req *http.Request) string { + return getContext[string](req, tokenKey) +} + +func getConfigByReq(req *http.Request) *structure.CONFIG { + return getConfig(getToken(req)) } func getConfig(token string) *structure.CONFIG { diff --git a/internals/proxy/middlewares/endpoints.go b/internals/proxy/middlewares/endpoints.go index bac6130a..c85f2899 100644 --- a/internals/proxy/middlewares/endpoints.go +++ b/internals/proxy/middlewares/endpoints.go @@ -5,8 +5,6 @@ import ( "path" "slices" "strings" - - log "github.com/codeshelldev/gotl/pkg/logger" ) var Endpoints Middleware = Middleware{ @@ -16,6 +14,8 @@ var Endpoints Middleware = Middleware{ func endpointsHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + logger := getLogger(req) + conf := getConfigByReq(req) endpoints := conf.SETTINGS.ACCESS.ENDPOINTS @@ -26,8 +26,8 @@ func endpointsHandler(next http.Handler) http.Handler { reqPath := req.URL.Path - if isBlocked(reqPath, endpoints) { - log.Warn("User tried to access blocked endpoint: ", reqPath) + if isEndpointBlocked(reqPath, endpoints) { + logger.Warn("Client tried to access blocked endpoint: ", reqPath) http.Error(w, "Forbidden", http.StatusForbidden) return } @@ -58,10 +58,10 @@ func matchesPattern(endpoint, pattern string) bool { return ok } -func isBlocked(endpoint string, endpoints []string) bool { - if len(endpoints) == 0 { - // default: block all - return true +func isEndpointBlocked(endpoint string, endpoints []string) bool { + if len(endpoints) == 0 || endpoints == nil { + // default: allow all + return false } allowed, blocked := getEndpoints(endpoints) @@ -82,16 +82,16 @@ func isBlocked(endpoint string, endpoints []string) bool { return true } - // only allowed endpoints -> block anything not allowed - if len(allowed) > 0 && len(blocked) == 0 { + // allow rules -> default deny + if len(allowed) > 0 { return true } - - // only blocked endpoints -> allow anything not blocked - if len(blocked) > 0 && len(allowed) == 0 { + + // only block rules -> default allow + if len(blocked) > 0 { return false } - // no match -> default: block all + // safety net -> block return true } diff --git a/internals/proxy/middlewares/hostname.go b/internals/proxy/middlewares/hostname.go new file mode 100644 index 00000000..1701b1c9 --- /dev/null +++ b/internals/proxy/middlewares/hostname.go @@ -0,0 +1,46 @@ +package middlewares + +import ( + "net/http" + "net/url" + "slices" +) + +var Hostname Middleware = Middleware{ + Name: "Hostname", + Use: hostnameHandler, +} + +func hostnameHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + logger := getLogger(req) + + conf := getConfigByReq(req) + + hostnames := conf.SERVICE.HOSTNAMES + + if hostnames == nil { + hostnames = getConfig("").SERVICE.HOSTNAMES + } + + if len(hostnames) > 0 { + URL := getContext[*url.URL](req, originURLKey) + + hostname := URL.Hostname() + + if hostname == "" { + logger.Error("Encountered empty hostname") + http.Error(w, "Bad Request: invalid hostname", http.StatusBadRequest) + return + } + + if !slices.Contains(hostnames, hostname) { + logger.Warn("Client tried using Token with wrong hostname") + onUnauthorized(w) + return + } + } + + next.ServeHTTP(w, req) + }) +} \ No newline at end of file diff --git a/internals/proxy/middlewares/ipfilter.go b/internals/proxy/middlewares/ipfilter.go new file mode 100644 index 00000000..5e118a0f --- /dev/null +++ b/internals/proxy/middlewares/ipfilter.go @@ -0,0 +1,95 @@ +package middlewares + +import ( + "net" + "net/http" + "slices" + "strings" +) + +var IPFilter Middleware = Middleware{ + Name: "IP Filter", + Use: ipFilterHandler, +} + +func ipFilterHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + logger := getLogger(req) + + conf := getConfigByReq(req) + + ipFilter := conf.SETTINGS.ACCESS.IP_FILTER + + if ipFilter == nil { + ipFilter = getConfig("").SETTINGS.ACCESS.IP_FILTER + } + + ip := getContext[net.IP](req, clientIPKey) + + if isIPBlocked(ip, ipFilter) { + logger.Warn("Client IP is blocked by filter: ", ip.String()) + http.Error(w, "Forbidden", http.StatusForbidden) + return + } + + next.ServeHTTP(w, req) + }) +} + +func getIPNets(ipNets []string) ([]string, []string) { + blockedIPNets := []string{} + allowedIPNets := []string{} + + for _, ipNet := range ipNets { + ip, block := strings.CutPrefix(ipNet, "!") + + if block { + blockedIPNets = append(blockedIPNets, ip) + } else { + allowedIPNets = append(allowedIPNets, ip) + } + } + + return allowedIPNets, blockedIPNets +} + +func isIPBlocked(ip net.IP, ipfilter []string) (bool) { + if len(ipfilter) == 0 || ipfilter == nil { + // default: allow all + return false + } + + rawAllowed, rawBlocked := getIPNets(ipfilter) + + allowed := parseIPsAndIPNets(rawAllowed) + blocked := parseIPsAndIPNets(rawBlocked) + + isExplicitlyAllowed := slices.ContainsFunc(allowed, func(try *net.IPNet) bool { + return try.Contains(ip) + }) + isExplicitlyBlocked := slices.ContainsFunc(blocked, func(try *net.IPNet) bool { + return try.Contains(ip) + }) + + // explicit allow > block + if isExplicitlyAllowed { + return false + } + + if isExplicitlyBlocked { + return true + } + + // allow rules -> default deny + if len(allowed) > 0 { + return true + } + + // only block rules -> default allow + if len(blocked) > 0 { + return false + } + + // safety net -> block + return true +} diff --git a/internals/proxy/middlewares/log.go b/internals/proxy/middlewares/log.go index c786535f..116702e2 100644 --- a/internals/proxy/middlewares/log.go +++ b/internals/proxy/middlewares/log.go @@ -1,31 +1,82 @@ package middlewares import ( + "net" "net/http" + "strings" - log "github.com/codeshelldev/gotl/pkg/logger" + "github.com/codeshelldev/gotl/pkg/logger" "github.com/codeshelldev/gotl/pkg/request" + "go.uber.org/zap/zapcore" ) -var Logging Middleware = Middleware{ +var RequestLogger Middleware = Middleware{ Name: "Logging", Use: loggingHandler, } +const loggerKey contextKey = "logger" + func loggingHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - if !log.IsDev() { - log.Info(req.Method, " ", req.URL.Path, " ", req.URL.RawQuery) + logger := getLogger(req) + + ip := getContext[net.IP](req, clientIPKey) + + if !logger.IsDev() { + logger.Info(ip.String(), " ", req.Method, " ", req.URL.Path, " ", req.URL.RawQuery) } else { body, _ := request.GetReqBody(req) if body.Data != nil && !body.Empty { - log.Dev(req.Method, " ", req.URL.Path, " ", req.URL.RawQuery, body.Data) + logger.Dev(ip.String(), " ", req.Method, " ", req.URL.Path, " ", req.URL.RawQuery, body.Data) } else { - log.Info(req.Method, " ", req.URL.Path, " ", req.URL.RawQuery) + logger.Info(ip.String(), " ", req.Method, " ", req.URL.Path, " ", req.URL.RawQuery) } } next.ServeHTTP(w, req) }) } + +var InternalMiddlewareLogger Middleware = Middleware{ + Name: "_Middleware_Logger", + Use: middlewareLoggerHandler, +} + +func middlewareLoggerHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + conf := getConfigByReq(req) + + logLevel := conf.SERVICE.LOG_LEVEL + + if strings.TrimSpace(logLevel) == "" { + logLevel = getConfig("").SERVICE.LOG_LEVEL + } + + options := logger.DefaultOptions() + options.EncodeCaller = func(caller zapcore.EntryCaller, enc zapcore.PrimitiveArrayEncoder) { + var name string + + if strings.TrimSpace(conf.NAME) != "" { + name = " " + conf.NAME + } + + enc.AppendString(caller.TrimmedPath() + name) + } + + l, err := logger.New(logLevel, options) + + if err != nil { + logger.Error("Could not create Middleware Logger: ", err.Error()) + } + + if l == nil { + l = logger.Get() + } + + req = setContext(req, loggerKey, l) + + next.ServeHTTP(w, req) + }) +} \ No newline at end of file diff --git a/internals/proxy/middlewares/mapping.go b/internals/proxy/middlewares/mapping.go index 0f92c4c7..be751495 100644 --- a/internals/proxy/middlewares/mapping.go +++ b/internals/proxy/middlewares/mapping.go @@ -4,7 +4,6 @@ import ( "net/http" jsonutils "github.com/codeshelldev/gotl/pkg/jsonutils" - log "github.com/codeshelldev/gotl/pkg/logger" request "github.com/codeshelldev/gotl/pkg/request" "github.com/codeshelldev/secured-signal-api/internals/config/structure" ) @@ -16,6 +15,8 @@ var Mapping Middleware = Middleware{ func mappingHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + logger := getLogger(req) + conf := getConfigByReq(req) variables := conf.SETTINGS.MESSAGE.VARIABLES @@ -32,8 +33,9 @@ func mappingHandler(next http.Handler) http.Handler { body, err := request.GetReqBody(req) if err != nil { - log.Error("Could not get Request Body: ", err.Error()) + logger.Error("Could not get Request Body: ", err.Error()) http.Error(w, "Bad Request: invalid body", http.StatusBadRequest) + return } var modifiedBody bool @@ -65,12 +67,12 @@ func mappingHandler(next http.Handler) http.Handler { err := body.Write(req) if err != nil { - log.Error("Could not write to Request Body: ", err.Error()) - http.Error(w, "Internal Error", http.StatusInternalServerError) + logger.Error("Could not write to Request Body: ", err.Error()) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } - log.Debug("Applied Data Aliasing: ", body.Data) + logger.Debug("Applied Data Aliasing: ", body.Data) } next.ServeHTTP(w, req) diff --git a/internals/proxy/middlewares/message.go b/internals/proxy/middlewares/message.go index 8af1a0b0..d5da10af 100644 --- a/internals/proxy/middlewares/message.go +++ b/internals/proxy/middlewares/message.go @@ -2,8 +2,8 @@ package middlewares import ( "net/http" + "strings" - log "github.com/codeshelldev/gotl/pkg/logger" request "github.com/codeshelldev/gotl/pkg/request" ) @@ -14,6 +14,8 @@ var Message Middleware = Middleware{ func messageHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + logger := getLogger(req) + conf := getConfigByReq(req) variables := conf.SETTINGS.MESSAGE.VARIABLES @@ -23,15 +25,16 @@ func messageHandler(next http.Handler) http.Handler { variables = getConfig("").SETTINGS.MESSAGE.VARIABLES } - if messageTemplate == "" { + if strings.TrimSpace(messageTemplate) == "" { messageTemplate = getConfig("").SETTINGS.MESSAGE.TEMPLATE } body, err := request.GetReqBody(req) if err != nil { - log.Error("Could not get Request Body: ", err.Error()) + logger.Error("Could not get Request Body: ", err.Error()) http.Error(w, "Bad Request: invalid body", http.StatusBadRequest) + return } bodyData := map[string]any{} @@ -47,7 +50,7 @@ func messageHandler(next http.Handler) http.Handler { newData, err := TemplateMessage(messageTemplate, bodyData, headerData, variables) if err != nil { - log.Error("Error Templating Message: ", err.Error()) + logger.Error("Error Templating Message: ", err.Error()) } if newData["message"] != bodyData["message"] && newData["message"] != "" && newData["message"] != nil { @@ -63,12 +66,12 @@ func messageHandler(next http.Handler) http.Handler { err := body.Write(req) if err != nil { - log.Error("Could not write to Request Body: ", err.Error()) - http.Error(w, "Internal Error", http.StatusInternalServerError) + logger.Error("Could not write to Request Body: ", err.Error()) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } - log.Debug("Applied Message Templating: ", body.Data) + logger.Debug("Applied Message Templating: ", body.Data) } next.ServeHTTP(w, req) diff --git a/internals/proxy/middlewares/middleware.go b/internals/proxy/middlewares/middleware.go index cc222e35..a6e76861 100644 --- a/internals/proxy/middlewares/middleware.go +++ b/internals/proxy/middlewares/middleware.go @@ -2,6 +2,7 @@ package middlewares import ( "net/http" + "strings" "github.com/codeshelldev/gotl/pkg/logger" ) @@ -22,7 +23,12 @@ func NewChain() *Chain { func (chain *Chain) Use(middleware Middleware) *Chain { chain.middlewares = append(chain.middlewares, middleware) - logger.Debug("Registered ", middleware.Name) + if strings.HasPrefix(middleware.Name, "_") { + logger.Dev("Registered ", middleware.Name, " middleware") + } else { + logger.Debug("Registered ", middleware.Name, " middleware") + } + return chain } diff --git a/internals/proxy/middlewares/policy.go b/internals/proxy/middlewares/policy.go index 2415129b..a35d0e2b 100644 --- a/internals/proxy/middlewares/policy.go +++ b/internals/proxy/middlewares/policy.go @@ -4,8 +4,8 @@ import ( "errors" "net/http" "reflect" + "regexp" - log "github.com/codeshelldev/gotl/pkg/logger" request "github.com/codeshelldev/gotl/pkg/request" "github.com/codeshelldev/secured-signal-api/internals/config/structure" "github.com/codeshelldev/secured-signal-api/utils/requestkeys" @@ -18,6 +18,8 @@ var Policy Middleware = Middleware{ func policyHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + logger := getLogger(req) + conf := getConfigByReq(req) policies := conf.SETTINGS.ACCESS.FIELD_POLICIES @@ -29,8 +31,9 @@ func policyHandler(next http.Handler) http.Handler { body, err := request.GetReqBody(req) if err != nil { - log.Error("Could not get Request Body: ", err.Error()) + logger.Error("Could not get Request Body: ", err.Error()) http.Error(w, "Bad Request: invalid body", http.StatusBadRequest) + return } if body.Empty { @@ -39,10 +42,10 @@ func policyHandler(next http.Handler) http.Handler { headerData := request.GetReqHeaders(req) - shouldBlock, field := doBlock(body.Data, headerData, policies) + shouldBlock, field := isBlockedByPolicy(body.Data, headerData, policies) if shouldBlock { - log.Warn("User tried to use blocked field: ", field) + logger.Warn("Client tried to use blocked field: ", field) http.Error(w, "Forbidden", http.StatusForbidden) return } @@ -51,20 +54,20 @@ func policyHandler(next http.Handler) http.Handler { }) } -func getPolicies(policies map[string]structure.FieldPolicy) (map[string]structure.FieldPolicy, map[string]structure.FieldPolicy) { - blockedFields := map[string]structure.FieldPolicy{} - allowedFields := map[string]structure.FieldPolicy{} +func getPolicies(policies []structure.FieldPolicy) ([]structure.FieldPolicy, []structure.FieldPolicy) { + blocked := []structure.FieldPolicy{} + allowed := []structure.FieldPolicy{} - for field, policy := range policies { + for _, policy := range policies { switch policy.Action { case "block": - blockedFields[field] = policy + blocked = append(blocked, policy) case "allow": - allowedFields[field] = policy + allowed = append(allowed, policy) } } - return allowedFields, blockedFields + return allowed, blocked } func getField(key string, body map[string]any, headers map[string][]string) (any, error) { @@ -76,34 +79,53 @@ func getField(key string, body map[string]any, headers map[string][]string) (any return value, nil } - return value, errors.New("field not found") + return nil, errors.New("field not found") } -func doPoliciesApply(body map[string]any, headers map[string][]string, policies map[string]structure.FieldPolicy) (bool, string) { - for key, policy := range policies { - value, err := getField(key, body, headers) +func doPoliciesApply(key string, body map[string]any, headers map[string][]string, policies []structure.FieldPolicy) (bool, string) { + value, err := getField(key, body, headers) - if err != nil { - continue - } + if err != nil { + return false, "" + } + for _, policy := range policies { switch asserted := value.(type) { case string: policyValue, ok := policy.Value.(string) + re, err := regexp.Compile(policyValue) + + if err == nil { + if re.MatchString(asserted) { + return true, key + } + continue + } + if ok && asserted == policyValue { return true, key } case int: - policyValue, ok := policy.Value.(int); + policyValue, ok := policy.Value.(int) if ok && asserted == policyValue { return true, key } - case bool: - policyValue, ok := policy.Value.(bool) + case float64: + var policyValue float64 + + // needed for json + switch assertedValue := policy.Value.(type) { + case int: + policyValue = float64(assertedValue) + case float64: + policyValue = assertedValue + default: + continue + } - if ok && asserted == policyValue { + if asserted == policyValue { return true, key } default: @@ -116,38 +138,51 @@ func doPoliciesApply(body map[string]any, headers map[string][]string, policies return false, "" } -func doBlock(body map[string]any, headers map[string][]string, policies map[string]structure.FieldPolicy) (bool, string) { - if len(policies) == 0 { +func isBlockedByPolicy(body map[string]any, headers map[string][]string, policies map[string][]structure.FieldPolicy) (bool, string) { + if len(policies) == 0 || policies == nil { // default: allow all return false, "" } - allowed, blocked := getPolicies(policies) + for field, policy := range policies { + if len(policy) == 0 || policy == nil { + continue + } - var cause string + value, _ := getField(field, body, headers) - isExplicitlyAllowed, cause := doPoliciesApply(body, headers, allowed) - isExplicitlyBlocked, cause := doPoliciesApply(body, headers, blocked) - - // explicit allow > block - if isExplicitlyAllowed { - return false, cause - } - - if isExplicitlyBlocked { - return true, cause - } + if value == nil { + continue + } - // only allow policies -> block anything not allowed - if len(allowed) > 0 && len(blocked) == 0 { - return true, cause - } + allowed, blocked := getPolicies(policy) + + isExplicitlyAllowed, cause := doPoliciesApply(field, body, headers, allowed) + isExplicitlyBlocked, cause := doPoliciesApply(field, body, headers, blocked) - // only block polcicies -> allow anything not blocked - if len(blocked) > 0 && len(allowed) == 0 { - return false, cause + // explicit allow > block + if isExplicitlyAllowed { + return false, cause + } + + if isExplicitlyBlocked { + return true, cause + } + + // allow rules -> default deny + if len(allowed) > 0 { + return true, cause + } + + // only block rules -> default allow + if len(blocked) > 0 { + return false, cause + } + + // safety net -> block + return true, "safety net" } - // no match -> default: block all - return true, cause + // default: allow all + return false, "" } diff --git a/internals/proxy/middlewares/port.go b/internals/proxy/middlewares/port.go new file mode 100644 index 00000000..84d30d9f --- /dev/null +++ b/internals/proxy/middlewares/port.go @@ -0,0 +1,54 @@ +package middlewares + +import ( + "errors" + "net" + "net/http" + "strings" +) + +var Port Middleware = Middleware{ + Name: "Port", + Use: portHandler, +} + +func portHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + logger := getLogger(req) + + conf := getConfigByReq(req) + + allowedPort := conf.SERVICE.PORT + + if strings.TrimSpace(allowedPort) == "" { + next.ServeHTTP(w, req) + return + } + + port, err := getPort(req) + + if err != nil { + logger.Error("Could not get Port: ", err.Error()) + http.Error(w, "Bad Request", http.StatusBadRequest) + return + } + + if port != allowedPort { + logger.Warn("Client tried using Token on wrong Port") + onUnauthorized(w) + return + } + }) +} + +func getPort(req *http.Request) (string, error) { + addr, ok := req.Context().Value(http.LocalAddrContextKey).(net.Addr) + + if !ok { + return "", errors.New("no local addr in context") + } + + _, port, err := net.SplitHostPort(addr.String()) + + return port, err +} \ No newline at end of file diff --git a/internals/proxy/middlewares/proxy.go b/internals/proxy/middlewares/proxy.go new file mode 100644 index 00000000..599d403c --- /dev/null +++ b/internals/proxy/middlewares/proxy.go @@ -0,0 +1,166 @@ +package middlewares + +import ( + "errors" + "net" + "net/http" + "net/url" + "strings" +) + +var InternalProxy Middleware = Middleware{ + Name: "_Proxy", + Use: proxyHandler, +} + +const trustedProxyKey contextKey = "isProxyTrusted" +const clientIPKey contextKey = "clientIP" +const originURLKey contextKey = "originURL" + +func proxyHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + logger := getLogger(req) + + conf := getConfigByReq(req) + + rawTrustedProxies := conf.SETTINGS.ACCESS.TRUSTED_PROXIES + + if rawTrustedProxies == nil { + rawTrustedProxies = getConfig("").SETTINGS.ACCESS.TRUSTED_PROXIES + } + + var trusted bool + var ip net.IP + + host, _, _ := net.SplitHostPort(req.RemoteAddr) + + originUrl := req.Proto + "://" + req.URL.Host + + ip = net.ParseIP(host) + + if len(rawTrustedProxies) != 0 { + trustedProxies := parseIPsAndIPNets(rawTrustedProxies) + + trusted = isIPInList(ip, trustedProxies) + } + + if trusted { + realIP, err := getRealIP(req) + + if err != nil { + logger.Error("Could not get real IP: ", err.Error()) + } + + if realIP != nil { + ip = realIP + } + + XFHost := req.Header.Get("X-Forwarded-Host") + XFProto := req.Header.Get("X-Forwarded-Proto") + XFPort := req.Header.Get("X-Forwarded-Port") + + if XFHost == "" || XFProto == "" || XFPort == "" { + logger.Warn("Missing X-Forwarded-* headers") + } + + originUrl = XFProto + "://" + XFHost + ":" + XFPort + } + + originURL, err := url.Parse(originUrl) + + if err != nil { + logger.Error("Could not parse Url: ", originUrl) + http.Error(w, "Bad Request: invalid Url", http.StatusBadRequest) + return + } + + req = setContext(req, trustedProxyKey, trusted) + req = setContext(req, originURLKey, originURL) + + req = setContext(req, clientIPKey, ip) + + next.ServeHTTP(w, req) + }) +} + +func parseIP(str string) (*net.IPNet, error) { + if !strings.Contains(str, "/") { + ip := net.ParseIP(str) + + if ip == nil { + return nil, errors.New("invalid ip: " + str) + } + + var mask net.IPMask + + if ip.To4() != nil { + mask = net.CIDRMask(32, 32) // IPv4 /32 + } else { + mask = net.CIDRMask(128, 128) // IPv6 /128 + } + + return &net.IPNet{IP: ip, Mask: mask}, nil + } + + ip, network, err := net.ParseCIDR(str) + if err != nil { + return nil, err + } + + if !ip.Equal(network.IP) { + var mask net.IPMask + + if ip.To4() != nil { + mask = net.CIDRMask(32, 32) // IPv4 /32 + } else { + mask = net.CIDRMask(128, 128) // IPv6 /128 + } + + return &net.IPNet{IP: ip, Mask: mask}, nil + } + + return network, nil +} + +func parseIPsAndIPNets(array []string) []*net.IPNet { + ipNets := []*net.IPNet{} + + for _, item := range array { + ipNet, err := parseIP(item) + + if err != nil { + continue + } + + ipNets = append(ipNets, ipNet) + } + + return ipNets +} + +func getRealIP(req *http.Request) (net.IP, error) { + XFF := req.Header.Get("X-Forwarded-For") + + if XFF != "" { + ips := strings.Split(XFF, ",") + + realIP := net.ParseIP(strings.TrimSpace(ips[0])) + + if realIP == nil { + return nil, errors.New("malformed X-Forwarded-For header") + } + + return realIP, nil + } + + return nil, errors.New("no X-Forwarded-For header present") +} + +func isIPInList(ip net.IP, list []*net.IPNet) bool { + for _, net := range list { + if net.Contains(ip) { + return true + } + } + return false +} \ No newline at end of file diff --git a/internals/proxy/middlewares/ratelimit.go b/internals/proxy/middlewares/ratelimit.go new file mode 100644 index 00000000..564f4b4b --- /dev/null +++ b/internals/proxy/middlewares/ratelimit.go @@ -0,0 +1,91 @@ +package middlewares + +import ( + "net/http" + "strings" + "time" + + "golang.org/x/time/rate" +) + +var RateLimit Middleware = Middleware{ + Name: "Rate Limiting", + Use: ratelimitHandler, +} + +type TokenLimiter struct { + limiter *rate.Limiter +} + +func NewTokenLimiter(limit int, period time.Duration) *TokenLimiter { + r := rate.Every(period / time.Duration(limit)) + + return &TokenLimiter{ + limiter: rate.NewLimiter(r, limit), + } +} + +func (t *TokenLimiter) Allow() bool { + return t.limiter.Allow() +} + +var tokenLimiters = map[string]*TokenLimiter{} + +func ratelimitHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + logger := getLogger(req) + + trusted := getContext[bool](req, trustedClientKey) + + if trusted { + next.ServeHTTP(w, req) + return + } + + conf := getConfigByReq(req) + + rateLimiting := conf.SETTINGS.ACCESS.RATE_LIMITING + + limit := rateLimiting.Limit + + if limit == 0 { + limit = getConfig("").SETTINGS.ACCESS.RATE_LIMITING.Limit + } + + periodStr := rateLimiting.Period + + if strings.TrimSpace(periodStr) == "" { + periodStr = conf.SETTINGS.ACCESS.RATE_LIMITING.Period + } + + if strings.TrimSpace(periodStr) != "" && limit != 0 { + period, err := time.ParseDuration(periodStr) + + if err != nil { + logger.Error("Could not parse Duration: ", err.Error()) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + token := getToken(req) + + tokenLimiter, exists := tokenLimiters[token] + + if !exists { + tokenLimiter = NewTokenLimiter(limit, period) + tokenLimiters[token] = tokenLimiter + } + + if !tokenLimiter.Allow() { + logger.Warn("Token exceeded Rate Limit") + + http.Error(w, "Too Many Requests", http.StatusTooManyRequests) + w.Header().Set("Retry-After", "60") + + return + } + } + + next.ServeHTTP(w, req) + }) +} \ No newline at end of file diff --git a/internals/proxy/middlewares/template.go b/internals/proxy/middlewares/template.go index eac13b5a..af7d1a2e 100644 --- a/internals/proxy/middlewares/template.go +++ b/internals/proxy/middlewares/template.go @@ -8,7 +8,6 @@ import ( "strings" jsonutils "github.com/codeshelldev/gotl/pkg/jsonutils" - log "github.com/codeshelldev/gotl/pkg/logger" query "github.com/codeshelldev/gotl/pkg/query" request "github.com/codeshelldev/gotl/pkg/request" templating "github.com/codeshelldev/gotl/pkg/templating" @@ -22,6 +21,8 @@ var Template Middleware = Middleware{ func templateHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + logger := getLogger(req) + conf := getConfigByReq(req) variables := conf.SETTINGS.MESSAGE.VARIABLES @@ -33,8 +34,9 @@ func templateHandler(next http.Handler) http.Handler { body, err := request.GetReqBody(req) if err != nil { - log.Error("Could not get Request Body: ", err.Error()) + logger.Error("Could not get Request Body: ", err.Error()) http.Error(w, "Bad Request: invalid body", http.StatusBadRequest) + return } bodyData := map[string]any{} @@ -49,7 +51,7 @@ func templateHandler(next http.Handler) http.Handler { bodyData, modified, err = TemplateBody(body.Data, headerData, variables) if err != nil { - log.Error("Error Templating JSON: ", err.Error()) + logger.Error("Error Templating JSON: ", err.Error()) } if modified { @@ -63,7 +65,7 @@ func templateHandler(next http.Handler) http.Handler { req.URL.RawQuery, bodyData, modified, err = TemplateQuery(req.URL, bodyData, variables) if err != nil { - log.Error("Error Templating Query: ", err.Error()) + logger.Error("Error Templating Query: ", err.Error()) } if modified { @@ -77,12 +79,12 @@ func templateHandler(next http.Handler) http.Handler { err := body.Write(req) if err != nil { - log.Error("Could not write to Request Body: ", err.Error()) - http.Error(w, "Internal Error", http.StatusInternalServerError) + logger.Error("Could not write to Request Body: ", err.Error()) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } - log.Debug("Applied Body Templating: ", body.Data) + logger.Debug("Applied Body Templating: ", body.Data) } if req.URL.Path != "" { @@ -91,11 +93,11 @@ func templateHandler(next http.Handler) http.Handler { req.URL.Path, modified, err = TemplatePath(req.URL, variables) if err != nil { - log.Error("Error Templating Path: ", err.Error()) + logger.Error("Error Templating Path: ", err.Error()) } if modified { - log.Debug("Applied Path Templating: ", req.URL.Path) + logger.Debug("Applied Path Templating: ", req.URL.Path) } } diff --git a/internals/proxy/proxy.go b/internals/proxy/proxy.go index 960930e2..b16335b3 100644 --- a/internals/proxy/proxy.go +++ b/internals/proxy/proxy.go @@ -31,9 +31,17 @@ func Create(targetUrl string) Proxy { func (proxy Proxy) Init() http.Handler { handler := m.NewChain(). - Use(m.Logging). Use(m.Server). Use(m.Auth). + Use(m.InternalMiddlewareLogger). + Use(m.InternalProxy). + Use(m.InternalClientIP). + Use(m.RequestLogger). + Use(m.InternalAuthRequirement). + Use(m.Port). + Use(m.Hostname). + Use(m.IPFilter). + Use(m.RateLimit). Use(m.Template). Use(m.Endpoints). Use(m.Mapping). diff --git a/internals/server/server.go b/internals/server/server.go new file mode 100644 index 00000000..680c0046 --- /dev/null +++ b/internals/server/server.go @@ -0,0 +1,127 @@ +package server + +import ( + "context" + "errors" + "net" + "net/http" + "sort" + "strconv" + "strings" + "sync" + + "github.com/codeshelldev/gotl/pkg/logger" +) + +type Server struct { + Host string + Ports []string + Handler http.Handler + Listeners map[string]*http.Server +} + +func Create(handler http.Handler, host string, ports ...string) *Server { + return &Server{ + Host: host, + Ports: ports, + Handler: handler, + Listeners: map[string]*http.Server{}, + } +} + +func (server *Server) ListenAndServer() { + var wg sync.WaitGroup + stopCh := make(chan struct{}) + + for _, port := range server.Ports { + addr := server.Host + ":" + port + listener, err := net.Listen("tcp", addr) + + if err != nil { + logger.Error("Error listening on ", port, ": ", err.Error()) + continue + } + + srv := &http.Server{ + Addr: server.Host + ":" + port, + Handler: server.Handler, + } + + wg.Add(1) + + go func(s *http.Server, l net.Listener, p string) { + defer wg.Done() + + logger.Debug("Listener on port ", port, " started") + + server.Listeners[port] = s + + err := s.Serve(l) + + if err != nil && err != http.ErrServerClosed { + logger.Error("Listener on port ", port, " exited with ", err.Error()) + } + }(srv, listener, port) + } + + go func() { + wg.Wait() + close(stopCh) + }() + + <- stopCh +} + +func (server *Server) Shutdown(ctx context.Context) error { + var errs []error + + for port, s := range server.Listeners { + logger.Debug("Shutting down listener on ", port) + + err := s.Shutdown(ctx) + + if err != nil { + errs = append(errs, err) + } + } + + return errors.Join(errs...) +} + +func PortsToRangeString(ports []string) string { + if len(ports) == 0 { + return "" + } + + sort.Strings(ports) + + result := []string{} + + end, _ := strconv.Atoi(ports[0]) + start, _ := strconv.Atoi(ports[0]) + + for i := 1; i < len(ports); i++ { + port, _ := strconv.Atoi(ports[i]) + + if port == end + 1 { + end = port + } else { + if start == end { + result = append(result, strconv.Itoa(start)) + } else { + result = append(result, strconv.Itoa(start) + "-" + strconv.Itoa(end)) + } + + start = port + end = port + } + } + + if start == end { + result = append(result, strconv.Itoa(start)) + } else { + result = append(result, strconv.Itoa(start) + "-" + strconv.Itoa(end)) + } + + return strings.Join(result, ",") +} \ No newline at end of file diff --git a/logo/banner.png b/logo/banner.png deleted file mode 100644 index 19705b7e..00000000 Binary files a/logo/banner.png and /dev/null differ diff --git a/main.go b/main.go index 7ba52a94..250604ae 100644 --- a/main.go +++ b/main.go @@ -1,12 +1,14 @@ package main import ( - "net/http" "os" + "slices" + "strings" - log "github.com/codeshelldev/gotl/pkg/logger" + "github.com/codeshelldev/gotl/pkg/logger" config "github.com/codeshelldev/secured-signal-api/internals/config" reverseProxy "github.com/codeshelldev/secured-signal-api/internals/proxy" + httpServer "github.com/codeshelldev/secured-signal-api/internals/server" docker "github.com/codeshelldev/secured-signal-api/utils/docker" ) @@ -15,36 +17,21 @@ var proxy reverseProxy.Proxy func main() { logLevel := os.Getenv("LOG_LEVEL") - log.Init(logLevel) + logger.Init(logLevel) docker.Init() config.Load() - if config.DEFAULT.SERVICE.LOG_LEVEL != log.Level() { - log.Init(config.DEFAULT.SERVICE.LOG_LEVEL) + if config.DEFAULT.SERVICE.LOG_LEVEL != logger.Level() { + logger.Init(config.DEFAULT.SERVICE.LOG_LEVEL) } - log.Info("Initialized Logger with Level of ", log.Level()) - - log.Info(` - - ┌────────────────────────────────────────────────┐ - │  🎄 Happy Holidays! 🎄  │ - │ │ - │ Thank you for using this project and for all  │ - │ the downloads, stars, and support this year.  │ - │ │ - │ Your support truly means a lot — here's to  │ - │ an awesome year ahead! ✨  │ - │ │ - │  - CodeShell  │ - └────────────────────────────────────────────────┘ - `) - - if log.Level() == "dev" { - log.Dev("Welcome back Developer!") - log.Dev("CTRL+S config to Print to Console") + logger.Info("Initialized Logger with Level of ", logger.Level()) + + if logger.Level() == "dev" { + logger.Dev("Welcome back, Developer!") + logger.Dev("CTRL+S config to Print to Console") } config.Log() @@ -53,23 +40,28 @@ func main() { handler := proxy.Init() - log.Info("Initialized Middlewares") + logger.Info("Initialized Middlewares") - addr := "0.0.0.0:" + config.DEFAULT.SERVICE.PORT + ports := []string{} - log.Info("Server Listening on ", addr) + for _, config := range config.ENV.CONFIGS { + port := strings.TrimSpace(config.SERVICE.PORT) - server := &http.Server{ - Addr: addr, - Handler: handler, + if port != "" && !slices.Contains(ports, port) { + ports = append(ports, port) + } } - stop := docker.Run(func() { - err := server.ListenAndServe() + server := httpServer.Create(handler, "0.0.0.0", ports...) - if err != nil && err != http.ErrServerClosed { - log.Fatal("Server error: ", err.Error()) + stop := docker.Run(func() { + if logger.IsDebug() && len(ports) > 1 { + logger.Debug("Server started with ", len(ports), " listeners on ", httpServer.PortsToRangeString(ports)) + } else { + logger.Info("Server listening on ", httpServer.PortsToRangeString(ports)) } + + server.ListenAndServer() }) <-stop diff --git a/utils/docker/docker.go b/utils/docker/docker.go index 21b4f2ff..2f3de3db 100644 --- a/utils/docker/docker.go +++ b/utils/docker/docker.go @@ -2,16 +2,16 @@ package docker import ( "context" - "net/http" "os" "time" "github.com/codeshelldev/gotl/pkg/docker" - log "github.com/codeshelldev/gotl/pkg/logger" + "github.com/codeshelldev/gotl/pkg/logger" + "github.com/codeshelldev/secured-signal-api/internals/server" ) func Init() { - log.Info("Running ", os.Getenv("IMAGE_TAG"), " Image") + logger.Info("Running ", os.Getenv("IMAGE_TAG"), " Image") } func Run(main func()) chan os.Signal { @@ -19,15 +19,15 @@ func Run(main func()) chan os.Signal { } func Exit(code int) { - log.Info("Exiting...") + logger.Info("Exiting...") docker.Exit(code) } -func Shutdown(server *http.Server) { - log.Info("Shutdown signal received") +func Shutdown(server *server.Server) { + logger.Info("Shutdown signal received") - log.Sync() + logger.Sync() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -35,10 +35,10 @@ func Shutdown(server *http.Server) { err := server.Shutdown(ctx) if err != nil { - log.Fatal("Server shutdown failed: ", err.Error()) + logger.Fatal("Server shutdown failed: ", err.Error()) - log.Info("Server exited forcefully") + logger.Info("Server exited forcefully") } else { - log.Info("Server exited gracefully") + logger.Info("Server exited gracefully") } }