diff --git a/README.md b/README.md index fc7843f3..52650e88 100644 --- a/README.md +++ b/README.md @@ -411,6 +411,12 @@ Closes connection when transmitted data exceeded limit. - `bytes`: number of bytes it should transmit before connection is closed +#### http_request_headers + +Modifies http request headers. This toxic only has effect when the direction equals `upstream`. The most common use case would be modifying the Host header when using reverse proxies. + + - `headers`: a key value map with the headers to set. + ### HTTP API All communication with the Toxiproxy daemon from the client happens through the diff --git a/toxics/http_request_headers.go b/toxics/http_request_headers.go new file mode 100644 index 00000000..8cd1c832 --- /dev/null +++ b/toxics/http_request_headers.go @@ -0,0 +1,56 @@ +package toxics + +import ( + "bufio" + "bytes" + "io" + "net/http" + "strings" + + "github.com/Shopify/toxiproxy/stream" +) + +// HttpToxic modifies requests headers (upstream) for http requests. Not to be used with direction = downstream +type HttpToxic struct { + Headers map[string]string `json:"headers"` +} + +func (t *HttpToxic) modifyRequest(request *http.Request) { + // Add all headers to request. Host is derived from the url if we dont set it explicitly. + for k, v := range t.Headers { + if strings.EqualFold("Host", k) { + request.Host = v + } else { + request.Header.Set(k, v) + } + } +} + +func (t *HttpToxic) Pipe(stub *ToxicStub) { + buffer := bytes.NewBuffer(make([]byte, 0, 32*1024)) + writer := stream.NewChanWriter(stub.Output) + reader := stream.NewChanReader(stub.Input) + reader.SetInterrupt(stub.Interrupt) + for { + tee := io.TeeReader(reader, buffer) + req, err := http.ReadRequest(bufio.NewReader(tee)) + if err == stream.ErrInterrupted { + buffer.WriteTo(writer) + return + } else if err == io.EOF { + stub.Close() + return + } + if err != nil { + buffer.WriteTo(writer) + } else { + t.modifyRequest(req) + req.Write(writer) + } + buffer.Reset() + } +} + +func init() { + Register("http_request_headers", new(HttpToxic)) +} diff --git a/toxics/http_request_headers_test.go b/toxics/http_request_headers_test.go new file mode 100644 index 00000000..87e3e2d3 --- /dev/null +++ b/toxics/http_request_headers_test.go @@ -0,0 +1,79 @@ +package toxics_test + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net" + "net/http" + "strings" + "testing" + + "github.com/Shopify/toxiproxy/toxics" +) + +func echoRequestHeaders(w http.ResponseWriter, r *http.Request) { + headersMap := map[string]string{} + + for k, v := range r.Header { + // headers can contain multiple elements. for the purposes of this test we pick the 1st + headersMap[k] = v[0] + } + + mapAsJson, _ := json.Marshal(headersMap) + w.Write([]byte(mapAsJson)) +} + +func TestToxicAddsHTTPHeaders(t *testing.T) { + http.HandleFunc("/", echoRequestHeaders) + + ln, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal("Failed to create TCP server", err) + } + + go http.Serve(ln, nil) + defer ln.Close() + + proxy := NewTestProxy("test", ln.Addr().String()) + proxy.Start() + defer proxy.Stop() + + resp, err := http.Get("http://" + proxy.Listen) + if err != nil { + t.Error("Failed to connect to proxy", err) + } + + body, err := ioutil.ReadAll(resp.Body) + + AssertDoesNotContainHeader(t, string(body), "Foo", "Bar") + AssertDoesNotContainHeader(t, string(body), "Lorem", "Ipsum") + + proxy.Toxics.AddToxicJson(ToxicToJson(t, "", "http_request_headers", "upstream", &toxics.HttpToxic{Headers: map[string]string{"Foo": "Bar", "Lorem": "Ipsum"}})) + + resp, err = http.Get("http://" + proxy.Listen) + if err != nil { + t.Error("Failed to connect to proxy", err) + } + + body, err = ioutil.ReadAll(resp.Body) + + AssertContainsHeader(t, string(body), "Foo", "Bar") + AssertContainsHeader(t, string(body), "Lorem", "Ipsum") +} + +func AssertDoesNotContainHeader(t *testing.T, body string, headerKey string, headerValue string) { + containsHeader := strings.Contains(string(body), fmt.Sprintf(`"%s":"%s"`, headerKey, headerValue)) + + if containsHeader { + t.Errorf("Unexpected header found. Header=%s", headerKey) + } +} + +func AssertContainsHeader(t *testing.T, body string, headerKey string, headerValue string) { + containsHeader := strings.Contains(string(body), fmt.Sprintf(`"%s":"%s"`, headerKey, headerValue)) + + if !containsHeader { + t.Errorf("Expected header not found. Header=%s", headerKey) + } +}