From e5b7366f1b2489f917e1ff74ac22010f920fde76 Mon Sep 17 00:00:00 2001 From: Jonathan Gaillard Date: Fri, 14 Jun 2019 15:35:25 -0700 Subject: [PATCH] Support wildcard for ExposedHeaders option. Via echoing back all headers in a wrapped response writer since browsers don't currently support the wildcard. Fixes #79 --- cors.go | 61 ++++++++++++++++++++++++- cors_test.go | 123 +++++++++++++++++++++++++++++++++++++++------------ 2 files changed, 154 insertions(+), 30 deletions(-) diff --git a/cors.go b/cors.go index caf330e..eea716c 100644 --- a/cors.go +++ b/cors.go @@ -54,7 +54,8 @@ type Options struct { // Default value is [] but "Origin" is always appended to the list. AllowedHeaders []string // ExposedHeaders indicates which headers are safe to expose to the API of a CORS - // API specification + // API specification. + // If the special "*" value is present in the list, all headers will be allowed. ExposedHeaders []string // MaxAge indicates how long (in seconds) the results of a preflight request // can be cached @@ -194,6 +195,7 @@ func AllowAll() *Cors { }, AllowedHeaders: []string{"*"}, AllowCredentials: false, + ExposedHeaders: []string{"*"}, }) } @@ -216,12 +218,15 @@ func (c *Cors) Handler(h http.Handler) http.Handler { } else { c.logf("Handler: Actual request") c.handleActualRequest(w, r) + w = &ExposeAllRespWriter{w, false} h.ServeHTTP(w, r) } }) } -// HandlerFunc provides Martini compatible handler +// HandlerFunc provides Martini compatible handler. +// Since a handler isn't wrapped using this func, considering using +// ExposeAllRespWriter for wildcard support. func (c *Cors) HandlerFunc(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" { c.logf("HandlerFunc: Preflight request") @@ -249,6 +254,7 @@ func (c *Cors) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.Handl } else { c.logf("ServeHTTP: Actual request") c.handleActualRequest(w, r) + w = &ExposeAllRespWriter{w, false} next(w, r) } } @@ -427,3 +433,54 @@ func (c *Cors) areHeadersAllowed(requestedHeaders []string) bool { } return true } + +// ExposeAllRespWriter echos back any headers that are set in the wrapped response writer +// to support the wildcard "*" case for Access-Control-Expose-Headers since +// browsers do not currently have good compatibility. +type ExposeAllRespWriter struct { + http.ResponseWriter + applied bool +} + +func (w *ExposeAllRespWriter) Write(b []byte) (int, error) { + w.setHeaders() + return w.ResponseWriter.Write(b) +} + +func (w *ExposeAllRespWriter) WriteHeader(c int) { + w.setHeaders() + w.ResponseWriter.WriteHeader(c) +} + +func (w *ExposeAllRespWriter) setHeaders() { + if w.applied { + return + } + w.applied = true + + if w.ResponseWriter.Header().Get("Access-Control-Expose-Headers") != "*" { + return + } + + var toExpose []string + for k := range w.ResponseWriter.Header() { + switch k { + case + // CORs headers that could be set when Access-Control-Expose-Headers is set + "Access-Control-Allow-Origin", "Access-Control-Allow-Credentials", "Access-Control-Expose-Headers", + + // already allowed by spec + "Cache-Control", "Content-Language", "Content-Type", "Expires", "Last-Modified", "Pragma": + continue + default: + toExpose = append(toExpose, k) + } + } + + if len(toExpose) == 0 { + w.ResponseWriter.Header().Del("Access-Control-Expose-Headers") + return + } + + w.ResponseWriter.Header().Set("Access-Control-Expose-Headers", strings.Join(toExpose, ", ")) +} diff --git a/cors_test.go b/cors_test.go index 68c12eb..80ec92b 100644 --- a/cors_test.go +++ b/cors_test.go @@ -4,6 +4,7 @@ import ( "net/http" "net/http/httptest" "regexp" + "sort" "strings" "testing" ) @@ -26,12 +27,23 @@ func assertHeaders(t *testing.T, resHeaders http.Header, expHeaders map[string]s for _, name := range allHeaders { got := strings.Join(resHeaders[name], ", ") want := expHeaders[name] + got = sortCSV(got) + want = sortCSV(want) if got != want { t.Errorf("Response header %q = %q, want %q", name, got, want) } } } +func sortCSV(s string) string { + ss := strings.Split(s, ",") + for i, s := range ss { + ss[i] = strings.TrimSpace(s) + } + sort.Strings(ss) + return strings.Join(ss, ", ") +} + func assertResponse(t *testing.T, res *httptest.ResponseRecorder, responseCode int) { if responseCode != res.Code { t.Errorf("assertResponse: expected response code to be %d but got %d. ", responseCode, res.Code) @@ -40,11 +52,12 @@ func assertResponse(t *testing.T, res *httptest.ResponseRecorder, responseCode i func TestSpec(t *testing.T) { cases := []struct { - name string - options Options - method string - reqHeaders map[string]string - resHeaders map[string]string + name string + options Options + method string + reqHeaders map[string]string + resHeaders map[string]string + wantHeaders map[string]string }{ { "NoConfig", @@ -52,7 +65,8 @@ func TestSpec(t *testing.T) { // Intentionally left blank. }, "GET", - map[string]string{}, + nil, + nil, map[string]string{ "Vary": "Origin", }, @@ -66,8 +80,9 @@ func TestSpec(t *testing.T) { map[string]string{ "Origin": "http://foobar.com", }, + nil, map[string]string{ - "Vary": "Origin", + "Vary": "Origin", "Access-Control-Allow-Origin": "*", }, }, @@ -81,8 +96,9 @@ func TestSpec(t *testing.T) { map[string]string{ "Origin": "http://foobar.com", }, + nil, map[string]string{ - "Vary": "Origin", + "Vary": "Origin", "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true", }, @@ -96,8 +112,9 @@ func TestSpec(t *testing.T) { map[string]string{ "Origin": "http://foobar.com", }, + nil, map[string]string{ - "Vary": "Origin", + "Vary": "Origin", "Access-Control-Allow-Origin": "http://foobar.com", }, }, @@ -110,8 +127,9 @@ func TestSpec(t *testing.T) { map[string]string{ "Origin": "http://foo.bar.com", }, + nil, map[string]string{ - "Vary": "Origin", + "Vary": "Origin", "Access-Control-Allow-Origin": "http://foo.bar.com", }, }, @@ -124,6 +142,7 @@ func TestSpec(t *testing.T) { map[string]string{ "Origin": "http://barbaz.com", }, + nil, map[string]string{ "Vary": "Origin", }, @@ -137,6 +156,7 @@ func TestSpec(t *testing.T) { map[string]string{ "Origin": "http://foo.baz.com", }, + nil, map[string]string{ "Vary": "Origin", }, @@ -152,8 +172,9 @@ func TestSpec(t *testing.T) { map[string]string{ "Origin": "http://foobar.com", }, + nil, map[string]string{ - "Vary": "Origin", + "Vary": "Origin", "Access-Control-Allow-Origin": "http://foobar.com", }, }, @@ -169,8 +190,9 @@ func TestSpec(t *testing.T) { "Origin": "http://foobar.com", "Authorization": "secret", }, + nil, map[string]string{ - "Vary": "Origin", + "Vary": "Origin", "Access-Control-Allow-Origin": "http://foobar.com", }, }, @@ -186,6 +208,7 @@ func TestSpec(t *testing.T) { "Origin": "http://foobar.com", "Authorization": "not-secret", }, + nil, map[string]string{ "Vary": "Origin", }, @@ -202,8 +225,9 @@ func TestSpec(t *testing.T) { "Origin": "http://example.com/", "Access-Control-Request-Method": "GET", }, + nil, map[string]string{ - "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", + "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", "Access-Control-Allow-Origin": "http://example.com/", "Access-Control-Allow-Methods": "GET", "Access-Control-Max-Age": "10", @@ -220,8 +244,9 @@ func TestSpec(t *testing.T) { "Origin": "http://foobar.com", "Access-Control-Request-Method": "PUT", }, + nil, map[string]string{ - "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", + "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", "Access-Control-Allow-Origin": "http://foobar.com", "Access-Control-Allow-Methods": "PUT", }, @@ -237,6 +262,7 @@ func TestSpec(t *testing.T) { "Origin": "http://foobar.com", "Access-Control-Request-Method": "PATCH", }, + nil, map[string]string{ "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", }, @@ -253,8 +279,9 @@ func TestSpec(t *testing.T) { "Access-Control-Request-Method": "GET", "Access-Control-Request-Headers": "X-Header-2, X-HEADER-1", }, + nil, map[string]string{ - "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", + "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", "Access-Control-Allow-Origin": "http://foobar.com", "Access-Control-Allow-Methods": "GET", "Access-Control-Allow-Headers": "X-Header-2, X-Header-1", @@ -272,8 +299,9 @@ func TestSpec(t *testing.T) { "Access-Control-Request-Method": "GET", "Access-Control-Request-Headers": "X-Requested-With", }, + nil, map[string]string{ - "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", + "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", "Access-Control-Allow-Origin": "http://foobar.com", "Access-Control-Allow-Methods": "GET", "Access-Control-Allow-Headers": "X-Requested-With", @@ -291,8 +319,9 @@ func TestSpec(t *testing.T) { "Access-Control-Request-Method": "GET", "Access-Control-Request-Headers": "X-Header-2, X-HEADER-1", }, + nil, map[string]string{ - "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", + "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", "Access-Control-Allow-Origin": "http://foobar.com", "Access-Control-Allow-Methods": "GET", "Access-Control-Allow-Headers": "X-Header-2, X-Header-1", @@ -310,6 +339,7 @@ func TestSpec(t *testing.T) { "Access-Control-Request-Method": "GET", "Access-Control-Request-Headers": "X-Header-3, X-Header-1", }, + nil, map[string]string{ "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", }, @@ -325,8 +355,9 @@ func TestSpec(t *testing.T) { "Access-Control-Request-Method": "GET", "Access-Control-Request-Headers": "origin", }, + nil, map[string]string{ - "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", + "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", "Access-Control-Allow-Origin": "http://foobar.com", "Access-Control-Allow-Methods": "GET", "Access-Control-Allow-Headers": "Origin", @@ -342,12 +373,38 @@ func TestSpec(t *testing.T) { map[string]string{ "Origin": "http://foobar.com", }, + nil, map[string]string{ - "Vary": "Origin", + "Vary": "Origin", "Access-Control-Allow-Origin": "http://foobar.com", "Access-Control-Expose-Headers": "X-Header-1, X-Header-2", }, }, + { + "ExposedWildcardHeader", + Options{ + AllowedOrigins: []string{"http://foobar.com"}, + ExposedHeaders: []string{"*"}, + }, + "GET", + map[string]string{ + "Origin": "http://foobar.com", + }, + map[string]string{ + "Custom-Header": "custom Value", + "Etag": "test etag", + + // already supported by spec + "Content-Type": "test content type", + }, + map[string]string{ + "Etag": "test etag", + "Custom-Header": "Test Value", + "Vary": "Origin", + "Access-Control-Allow-Origin": "http://foobar.com", + "Access-Control-Expose-Headers": "Custom-Header, Vary, Etag", + }, + }, { "AllowedCredentials", Options{ @@ -359,8 +416,9 @@ func TestSpec(t *testing.T) { "Origin": "http://foobar.com", "Access-Control-Request-Method": "GET", }, + nil, map[string]string{ - "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", + "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", "Access-Control-Allow-Origin": "http://foobar.com", "Access-Control-Allow-Methods": "GET", "Access-Control-Allow-Credentials": "true", @@ -376,8 +434,9 @@ func TestSpec(t *testing.T) { "Origin": "http://foobar.com", "Access-Control-Request-Method": "GET", }, + nil, map[string]string{ - "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", + "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Methods": "GET", }, @@ -388,13 +447,20 @@ func TestSpec(t *testing.T) { AllowedOrigins: []string{"http://foobar.com"}, }, "OPTIONS", - map[string]string{}, - map[string]string{}, + nil, + nil, + nil, }, } for i := range cases { tc := cases[i] t.Run(tc.name, func(t *testing.T) { + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for k, v := range tc.resHeaders { + w.Header().Set(k, v) + } + w.Write([]byte("bar")) + }) s := New(tc.options) req, _ := http.NewRequest(tc.method, "http://example.com/foo", nil) @@ -404,18 +470,19 @@ func TestSpec(t *testing.T) { t.Run("Handler", func(t *testing.T) { res := httptest.NewRecorder() - s.Handler(testHandler).ServeHTTP(res, req) - assertHeaders(t, res.Header(), tc.resHeaders) + s.Handler(h).ServeHTTP(res, req) + assertHeaders(t, res.Header(), tc.wantHeaders) }) t.Run("HandlerFunc", func(t *testing.T) { res := httptest.NewRecorder() s.HandlerFunc(res, req) - assertHeaders(t, res.Header(), tc.resHeaders) + h.ServeHTTP(&ExposeAllRespWriter{ResponseWriter: res}, req) + assertHeaders(t, res.Header(), tc.wantHeaders) }) t.Run("Negroni", func(t *testing.T) { res := httptest.NewRecorder() - s.ServeHTTP(res, req, testHandler) - assertHeaders(t, res.Header(), tc.resHeaders) + s.ServeHTTP(res, req, h) + assertHeaders(t, res.Header(), tc.wantHeaders) }) })