diff --git a/responsewriter.go b/responsewriter.go index 539d14a..b442d08 100644 --- a/responsewriter.go +++ b/responsewriter.go @@ -3,6 +3,7 @@ package sigsci import ( "bufio" "fmt" + "io" "net" "net/http" ) @@ -75,6 +76,13 @@ func (w *responseRecorder) Write(b []byte) (int, error) { return w.base.Write(b) } +func (w *responseRecorder) ReadFrom(r io.Reader) (n int64, err error) { + if rf, ok := w.base.(io.ReaderFrom); ok { + return rf.ReadFrom(r) + } + return io.Copy(w.base, r) +} + // Hijack hijacks the connection from the HTTP handler so that it can be used directly (websockets, etc.) // NOTE: This will fail if the wrapped http.responseRecorder is not a http.Hijacker. func (w *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { @@ -106,3 +114,8 @@ func (w *responseRecorderFlusher) Flush() { f.Flush() } } + +// ensure our writers satisfy the intended interfaces +var _ http.Hijacker = (*responseRecorder)(nil) +var _ io.ReaderFrom = (*responseRecorder)(nil) +var _ http.Flusher = (*responseRecorderFlusher)(nil) diff --git a/responsewriter_test.go b/responsewriter_test.go index 97ba218..5672f0d 100644 --- a/responsewriter_test.go +++ b/responsewriter_test.go @@ -2,6 +2,7 @@ package sigsci import ( "fmt" + "io" "io/ioutil" "net/http" "net/http/httptest" @@ -25,6 +26,10 @@ func (w *testResponseRecorder) Write(b []byte) (int, error) { return w.Recorder.Write(b) } +func (w *testResponseRecorder) ReadFrom(r io.Reader) (n int64, err error) { + return io.Copy(w.Recorder, r) +} + // testResponseRecorderFlusher is a httptest.ResponseRecorder with the Flusher interface type testResponseRecorderFlusher struct { Recorder *httptest.ResponseRecorder @@ -42,6 +47,10 @@ func (w *testResponseRecorderFlusher) Write(b []byte) (int, error) { return w.Recorder.Write(b) } +func (w *testResponseRecorderFlusher) ReadFrom(r io.Reader) (n int64, err error) { + return io.Copy(w.Recorder, r) +} + func (w *testResponseRecorderFlusher) Flush() { w.Recorder.Flush() }