Skip to content

Commit

Permalink
Merge pull request #8 from dragonsinth/readfrom
Browse files Browse the repository at this point in the history
responseRecorders should implement io.ReaderFrom
  • Loading branch information
brectanus-sigsci authored Apr 1, 2020
2 parents 7404cbf + c6b5a39 commit 3efb8a4
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
13 changes: 13 additions & 0 deletions responsewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sigsci
import (
"bufio"
"fmt"
"io"
"net"
"net/http"
)
Expand Down Expand Up @@ -75,6 +76,13 @@ func (w *responseRecorder) Write(b []byte) (int, error) {
return w.base.Write(b)
}

func (w *responseRecorder) ReadFrom(r io.Reader) (n int64, err error) {
if rf, ok := w.base.(io.ReaderFrom); ok {
return rf.ReadFrom(r)
}
return io.Copy(w.base, r)
}

// Hijack hijacks the connection from the HTTP handler so that it can be used directly (websockets, etc.)
// NOTE: This will fail if the wrapped http.responseRecorder is not a http.Hijacker.
func (w *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
Expand Down Expand Up @@ -106,3 +114,8 @@ func (w *responseRecorderFlusher) Flush() {
f.Flush()
}
}

// ensure our writers satisfy the intended interfaces
var _ http.Hijacker = (*responseRecorder)(nil)
var _ io.ReaderFrom = (*responseRecorder)(nil)
var _ http.Flusher = (*responseRecorderFlusher)(nil)
9 changes: 9 additions & 0 deletions responsewriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sigsci

import (
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
Expand All @@ -25,6 +26,10 @@ func (w *testResponseRecorder) Write(b []byte) (int, error) {
return w.Recorder.Write(b)
}

func (w *testResponseRecorder) ReadFrom(r io.Reader) (n int64, err error) {
return io.Copy(w.Recorder, r)
}

// testResponseRecorderFlusher is a httptest.ResponseRecorder with the Flusher interface
type testResponseRecorderFlusher struct {
Recorder *httptest.ResponseRecorder
Expand All @@ -42,6 +47,10 @@ func (w *testResponseRecorderFlusher) Write(b []byte) (int, error) {
return w.Recorder.Write(b)
}

func (w *testResponseRecorderFlusher) ReadFrom(r io.Reader) (n int64, err error) {
return io.Copy(w.Recorder, r)
}

func (w *testResponseRecorderFlusher) Flush() {
w.Recorder.Flush()
}
Expand Down

0 comments on commit 3efb8a4

Please sign in to comment.