From 85a787b41919ae056906301a911d73c181fe8ef4 Mon Sep 17 00:00:00 2001 From: Alexey Michurin Date: Sat, 6 Jan 2024 06:10:41 +0300 Subject: [PATCH] Minor improvements: modern methods, less panic-prone --- go.mod | 2 +- main.go | 51 ++++++++++++++++++++++++++++++--------------------- 2 files changed, 31 insertions(+), 22 deletions(-) diff --git a/go.mod b/go.mod index 0a5ec84..b39afa1 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,5 @@ module github.com/digitalocean/sample-golang -go 1.13 +go 1.18 require github.com/gofrs/uuid v3.3.0+incompatible // indirect diff --git a/main.go b/main.go index fc3cb3d..434fc17 100644 --- a/main.go +++ b/main.go @@ -41,22 +41,28 @@ func main() { http.HandleFunc("/cached", func(w http.ResponseWriter, r *http.Request) { logRequest(r) - maxAgeParams, ok := r.URL.Query()["max-age"] - if ok && len(maxAgeParams) > 0 { - maxAge, _ := strconv.Atoi(maxAgeParams[0]) + query := r.URL.Query() + maxAgeParam := query.Get("max-age") + if len(maxAgeParam) > 0 { + maxAge, _ := strconv.Atoi(maxAgeParam) w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d", maxAge)) } - responseHeaderParams, ok := r.URL.Query()["headers"] + responseHeaderParams, ok := query["headers"] if ok { for _, header := range responseHeaderParams { - h := strings.Split(header, ":") - w.Header().Set(h[0], strings.TrimSpace(h[1])) + h, v, ok := strings.Cut(header, ":") + if !ok { + continue + } + w.Header().Set(h, strings.TrimSpace(v)) } } - statusCodeParams, ok := r.URL.Query()["status"] - if ok { - statusCode, _ := strconv.Atoi(statusCodeParams[0]) - w.WriteHeader(statusCode) + statusCodeParam := query.Get("status") + if len(statusCodeParam) > 0 { + statusCode, _ := strconv.Atoi(statusCodeParam) + if statusCode >= 200 && statusCode < 600 { + w.WriteHeader(statusCode) + } } requestID := uuid.Must(uuid.NewV4()) fmt.Fprint(w, requestID.String()) @@ -64,9 +70,9 @@ func main() { http.HandleFunc("/headers", func(w http.ResponseWriter, r *http.Request) { logRequest(r) - keys, ok := r.URL.Query()["key"] - if ok && len(keys) > 0 { - fmt.Fprint(w, r.Header.Get(keys[0])) + key := r.URL.Query().Get("key") + if len(key) > 0 { + fmt.Fprint(w, r.Header.Get(key)) return } headers := []string{} @@ -107,15 +113,18 @@ func main() { port = "80" } - for _, encodedRoute := range strings.Split(os.Getenv("ROUTES"), ",") { - if encodedRoute == "" { - continue + encodedRouteString := os.Getenv("ROUTES") + if encodedRouteString != "" { + for _, encodedRoute := range strings.Split(encodedRouteString, ",") { + path, body, ok := strings.Cut(encodedRoute, "=") + if !ok { + fmt.Printf("Skip routing %q: wrong format", encodedRoute) + continue + } + http.HandleFunc("/"+path, func(w http.ResponseWriter, _ *http.Request) { + fmt.Fprint(w, body) + }) } - pathAndBody := strings.SplitN(encodedRoute, "=", 2) - path, body := pathAndBody[0], pathAndBody[1] - http.HandleFunc("/"+path, func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, body) - }) } bindAddr := fmt.Sprintf(":%s", port)