-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcors.go
More file actions
165 lines (142 loc) · 4.08 KB
/
cors.go
File metadata and controls
165 lines (142 loc) · 4.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
package main
import (
"net/http"
"regexp"
"strings"
)
// corsAllowOrigin selects which value a preflight response carries in its
// Access-Control-Allow-Origin header.
type corsAllowOrigin int

const (
	// corsAllowOriginAll answers with the wildcard "*".
	corsAllowOriginAll corsAllowOrigin = iota
	// corsAllowOriginSource echoes the request's Origin header back and
	// lists Origin in the Vary header.
	corsAllowOriginSource
	// corsAllowOriginNull answers with the literal "null".
	corsAllowOriginNull
)
// corsACL is a single CORS rule: patterns that select the requests it applies
// to, plus the values used to build the CORS response headers.
type corsACL struct {
	// hostRe is matched against the request Host; nil means the host is not
	// checked.
	hostRe *regexp.Regexp
	// pathRe is matched against the request URL path.
	pathRe *regexp.Regexp
	// originRe is matched against the request's Origin header.
	originRe *regexp.Regexp
	// allowOrigin picks the Access-Control-Allow-Origin value sent in
	// preflight responses.
	allowOrigin corsAllowOrigin
	// allowMethods is sent as Access-Control-Allow-Methods; when empty the
	// preflight echoes the requested method instead.
	allowMethods string
	// allowHeaders is sent as Access-Control-Allow-Headers; when nil the
	// preflight echoes the requested headers instead.
	allowHeaders []string
	// allowCredentials, when non-empty, is sent verbatim as
	// Access-Control-Allow-Credentials.
	allowCredentials string
}
// CORSHandler is an http.Handler wrapper that applies CORS rules before
// delegating to the embedded Handler. Preflight requests matching a rule are
// answered directly; other requests carrying an Origin header that matches a
// rule get an Access-Control-Allow-Origin response header. When the embedded
// Handler is nil, http.DefaultServeMux is used.
type CORSHandler struct {
	http.Handler
	// acls holds the rules in the order they were added; the first match wins.
	acls []corsACL
}
// AddRecord registers a CORS rule that makes requests whose URL path matches
// the regular expression path accessible from origins matching the regular
// expression origin.
//
// Recognised keys in opts (all optional):
//   - "host_re": regular expression the request Host must match.
//   - "methods": colon-separated list sent as Access-Control-Allow-Methods
//     (default "*").
//   - "headers": colon-separated list sent as Access-Control-Allow-Headers.
//   - "creds": value sent verbatim as Access-Control-Allow-Credentials.
//
// An error is returned, and no rule is added, if any of the regular
// expressions fails to compile.
func (ch *CORSHandler) AddRecord(path, origin string, opts map[string]string) error {
	// The patterns are caller-supplied, so compile them with regexp.Compile
	// and report failures through the error return instead of panicking.
	// The regexp error text already includes the offending expression.
	pathRe, err := regexp.Compile(path)
	if err != nil {
		return err
	}
	originRe, err := regexp.Compile(origin)
	if err != nil {
		return err
	}

	acl := corsACL{
		allowOrigin:  corsAllowOriginSource,
		allowMethods: "*",
		pathRe:       pathRe,
		originRe:     originRe,
	}

	if hostReStr, have := opts["host_re"]; have {
		hostRe, err := regexp.Compile(hostReStr)
		if err != nil {
			return err
		}
		acl.hostRe = hostRe
	}
	if methodsStr, have := opts["methods"]; have {
		// "GET:POST" -> "GET, POST"
		acl.allowMethods = strings.Join(strings.Split(methodsStr, ":"), ", ")
	}
	if hdrString, have := opts["headers"]; have {
		acl.allowHeaders = strings.Split(hdrString, ":")
	}
	if creds, have := opts["creds"]; have {
		acl.allowCredentials = creds
	}

	logf(nil, logLevelInfo,
		"CORS: Adding origin host=%s %#v on %#v (methods %#v)",
		acl.hostRe, origin, path, acl.allowMethods)

	// append works on a nil slice, so ch.acls needs no explicit initialisation.
	ch.acls = append(ch.acls, acl)
	return nil
}
// handlePreflight answers a CORS preflight request from the first ACL whose
// host, path and origin patterns all match, replying 204 No Content with the
// corresponding Access-Control-Allow-* headers. When no ACL matches, a
// warning is logged and the request is passed on to the wrapped handler
// (http.DefaultServeMux if none is set).
func (ch *CORSHandler) handlePreflight(w http.ResponseWriter, r *http.Request) {
	origin := r.Header.Get("Origin")
	acrMethod := r.Header.Get("Access-Control-Request-Method")
	acrHeaders := r.Header.Get("Access-Control-Request-Headers")

	for _, acl := range ch.acls {
		if acl.hostRe != nil && !acl.hostRe.MatchString(r.Host) {
			continue
		}
		if !acl.pathRe.MatchString(r.URL.Path) || !acl.originRe.MatchString(origin) {
			continue
		}

		hdr := w.Header()

		// Vary names the *request* headers whose values were used to build
		// this response, so caches key on them. It must not list the allowed
		// header names themselves.
		var varyHeaders []string

		// allowed origin
		switch acl.allowOrigin {
		case corsAllowOriginAll:
			hdr.Add("Access-Control-Allow-Origin", "*")
		case corsAllowOriginSource:
			hdr.Add("Access-Control-Allow-Origin", origin)
			varyHeaders = append(varyHeaders, "Origin")
		case corsAllowOriginNull:
			hdr.Add("Access-Control-Allow-Origin", "null")
		}

		// allowed methods: the configured list, or an echo of the request.
		if acl.allowMethods == "" {
			hdr.Add("Access-Control-Allow-Methods", acrMethod)
			varyHeaders = append(varyHeaders, "Access-Control-Request-Method")
		} else {
			hdr.Add("Access-Control-Allow-Methods", acl.allowMethods)
		}

		// allowed headers: only sent when the client asked for some.
		if acrHeaders != "" {
			acaHeaders := acl.allowHeaders
			if acaHeaders == nil {
				// No configured list: echo the requested names, canonicalised,
				// dropping surrounding whitespace and empty list members.
				for _, header := range strings.Split(acrHeaders, ",") {
					if name := strings.TrimSpace(header); name != "" {
						acaHeaders = append(acaHeaders, http.CanonicalHeaderKey(name))
					}
				}
				varyHeaders = append(varyHeaders, "Access-Control-Request-Headers")
			}
			hdr.Add("Access-Control-Allow-Headers", strings.Join(acaHeaders, ", "))
		}

		// allowed credentials
		if acl.allowCredentials != "" {
			hdr.Add("Access-Control-Allow-Credentials", acl.allowCredentials)
		}

		// vary headers
		if len(varyHeaders) > 0 {
			hdr.Add("Vary", strings.Join(varyHeaders, ", "))
		}

		w.WriteHeader(http.StatusNoContent)
		return
	}

	logf(r, logLevelWarning,
		"CORS: Could not match origin %#v on %#v, passing to backend",
		origin, r.URL.Path)

	next := ch.Handler
	if next == nil {
		next = http.DefaultServeMux
	}
	next.ServeHTTP(w, r)
}
// ServeHTTP implements http.Handler.
//
// An OPTIONS request carrying both Origin and Access-Control-Request-Method
// is treated as a CORS preflight and handed to handlePreflight. For any other
// request with an Origin header, the first ACL whose origin, host and path
// patterns match adds the CORS response headers before the request is
// forwarded to the wrapped handler (http.DefaultServeMux if none is set).
func (ch *CORSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	next := ch.Handler
	if next == nil {
		next = http.DefaultServeMux
	}

	origin := r.Header.Get("Origin")
	if origin == "" {
		// No Origin header: neither the preflight path nor any ACL applies.
		next.ServeHTTP(w, r)
		return
	}

	if r.Method == http.MethodOptions &&
		r.Header.Get("Access-Control-Request-Method") != "" {
		ch.handlePreflight(w, r)
		return
	}

	for _, acl := range ch.acls {
		if !acl.originRe.MatchString(origin) ||
			(acl.hostRe != nil && !acl.hostRe.MatchString(r.Host)) ||
			!acl.pathRe.MatchString(r.URL.Path) {
			continue
		}

		hdr := w.Header()
		hdr.Add("Access-Control-Allow-Origin", origin)
		// The allowed origin is an echo of the request's Origin, so caches
		// must key on it.
		hdr.Add("Vary", "Origin")
		// The credentials flag has to be present on the actual response as
		// well as the preflight, otherwise browsers reject credentialed
		// requests.
		if acl.allowCredentials != "" {
			hdr.Add("Access-Control-Allow-Credentials", acl.allowCredentials)
		}
		break
	}

	next.ServeHTTP(w, r)
}