From 4830069343e59b0c40789613cc9306ec248162fb Mon Sep 17 00:00:00 2001 From: Manuel Romei Date: Fri, 4 Feb 2022 22:42:40 +0100 Subject: [PATCH 1/3] feat: first draft of the download and upload feature --- cmd/guac/guac.go | 121 ++++++++++--- filter.go | 5 + go.mod | 1 + go.sum | 6 + intercepted_stream.go | 25 +++ server.go | 40 +++-- stream.go | 5 +- stream_intercepting_filter.go | 314 ++++++++++++++++++++++++++++++++++ stream_intercepting_tunnel.go | 88 ++++++++++ tunnel.go | 55 +++++- ws_server.go | 39 +++-- ws_server_test.go | 2 +- 12 files changed, 643 insertions(+), 58 deletions(-) create mode 100644 filter.go create mode 100644 intercepted_stream.go create mode 100644 stream_intercepting_filter.go create mode 100644 stream_intercepting_tunnel.go diff --git a/cmd/guac/guac.go b/cmd/guac/guac.go index 515b8ef..d392982 100644 --- a/cmd/guac/guac.go +++ b/cmd/guac/guac.go @@ -9,46 +9,110 @@ import ( "net/url" "strconv" + "github.com/gorilla/mux" + "github.com/gorilla/websocket" "github.com/sirupsen/logrus" "github.com/wwt/guac" ) +var tunnels map[string]guac.Tunnel + func main() { logrus.SetLevel(logrus.DebugLevel) - servlet := guac.NewServer(DemoDoConnect) + // servlet := guac.NewServer(DemoDoConnect) wsServer := guac.NewWebsocketServer(DemoDoConnect) sessions := guac.NewMemorySessionStore() wsServer.OnConnect = sessions.Add wsServer.OnDisconnect = sessions.Delete - mux := http.NewServeMux() - mux.Handle("/tunnel", servlet) - mux.Handle("/tunnel/", servlet) - mux.Handle("/websocket-tunnel", wsServer) - mux.HandleFunc("/sessions/", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") + tunnels = make(map[string]guac.Tunnel, 0) + + wsServer.OnConnectWs = func(s string, _ *websocket.Conn, _ *http.Request, t guac.Tunnel) { + tunnels[s] = t + } + + wsServer.OnDisconnectWs = func(s string, _ *websocket.Conn, _ *http.Request, _ guac.Tunnel) { + delete(tunnels, s) + } + + m := mux.NewRouter() + // m.Handle("/", servlet) + m.Handle("/websocket-tunnel", wsServer) + + m.HandleFunc("/api/session/tunnels/{tunnel}/streams/{stream}/{file}", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Disposition", "attachment") + t := mux.Vars(r)["tunnel"] + + tunnel, ok := tunnels[t] + if !ok { + w.Write([]byte("KO")) + w.WriteHeader(http.StatusInternalServerError) + return + } + + sit, ok := tunnel.(*guac.StreamInterceptingTunnel) + if !ok { + w.Write([]byte("Not supported")) + w.WriteHeader(http.StatusBadRequest) + return + } + + stream := mux.Vars(r)["stream"] + + streamIndex, err := strconv.Atoi(stream) + if err != nil { + w.Write([]byte("KO integer")) + w.WriteHeader(http.StatusBadRequest) + return + } - sessions.RLock() - defer sessions.RUnlock() + if err := sit.InterceptOutputStream(streamIndex, w); err != nil { + w.Write([]byte("KO Intercepting output stream")) + } + }).Methods("GET") + + m.HandleFunc("/api/session/tunnels/{tunnel}/streams/{stream}/{file}", func(w http.ResponseWriter, r *http.Request) { + // w.Header().Set("Content-Type", "application/json") + t := mux.Vars(r)["tunnel"] + tunnel, ok := tunnels[t] + if !ok { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("KO")) + return + } - type ConnIds struct { - Uuid string `json:"uuid"` - Num int `json:"num"` + sit, ok := tunnel.(*guac.StreamInterceptingTunnel) + if !ok { + w.Write([]byte("Not supported")) + w.WriteHeader(http.StatusBadRequest) + return } - connIds := make([]*ConnIds, len(sessions.ConnIds)) + stream := mux.Vars(r)["stream"] + + streamIndex, err := strconv.Atoi(stream) + if err != nil { + w.Write([]byte("KO integer")) + w.WriteHeader(http.StatusBadRequest) + return + } - i := 0 - for id, num := range sessions.ConnIds { - connIds[i] = &ConnIds{ - Uuid: id, - Num: num, - } + if err := sit.InterceptInputStream(streamIndex, r.Body); err != nil { + w.Write([]byte("KO intercepting input stream")) } + }).Methods("POST") + + m.HandleFunc("/api/session/tunnels", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(connIds); err != nil { + t := []string{} + for tun := range tunnels { + t = append(t, tun) + } + + if err := json.NewEncoder(w).Encode(t); err != nil { logrus.Error(err) } }) @@ -57,7 +121,7 @@ func main() { s := &http.Server{ Addr: "0.0.0.0:4567", - Handler: mux, + Handler: m, ReadTimeout: guac.SocketTimeout, WriteTimeout: guac.SocketTimeout, MaxHeaderBytes: 1 << 20, @@ -93,7 +157,11 @@ func DemoDoConnect(request *http.Request) (guac.Tunnel, error) { } config.Protocol = query.Get("scheme") - config.Parameters = map[string]string{} + config.Parameters = map[string]string{ + "enable-sftp": "true", + "sftp-hostname": "198.18.251.1", + "sftp-port": "22", + } for k, v := range query { config.Parameters[k] = v[0] } @@ -116,7 +184,11 @@ func DemoDoConnect(request *http.Request) (guac.Tunnel, error) { config.AudioMimetypes = []string{"audio/L16", "rate=44100", "channels=2"} logrus.Debug("Connecting to guacd") - addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:4822") + addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:4444") + if err != nil { + logrus.Errorln("error while resolving 127.0.0.1") + return nil, err + } conn, err := net.DialTCP("tcp", nil, addr) if err != nil { @@ -136,5 +208,6 @@ func DemoDoConnect(request *http.Request) (guac.Tunnel, error) { return nil, err } logrus.Debug("Socket configured") - return guac.NewSimpleTunnel(stream), nil + + return guac.NewStreamInterceptingTunnel(guac.NewSimpleTunnel(stream)), nil } diff --git a/filter.go b/filter.go new file mode 100644 index 0000000..d45e001 --- /dev/null +++ b/filter.go @@ -0,0 +1,5 @@ +package guac + +type Filter interface { + Filter(*Instruction) (*Instruction, error) +} diff --git a/go.mod b/go.mod index 20ab153..f52475b 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ replace github.com/Sirupsen/logrus v1.4.2 => github.com/sirupsen/logrus v1.4.2 require ( github.com/google/uuid v1.1.1 + github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.4.1 github.com/sirupsen/logrus v1.4.2 ) diff --git a/go.sum b/go.sum index ab4a565..713fc13 100644 --- a/go.sum +++ b/go.sum @@ -1,13 +1,19 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= +github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894 h1:Cz4ceDQGXuKRnVBDTS23GTn/pU5OE2C0WrNTOYK1Uuc= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/intercepted_stream.go b/intercepted_stream.go new file mode 100644 index 0000000..624b5eb --- /dev/null +++ b/intercepted_stream.go @@ -0,0 +1,25 @@ +package guac + +import "io" + +type InterceptedOutputStream struct { + Index string + Stream io.Writer + Error error + Closed chan bool +} + +func NewInterceptedOutputStream(index string, stream io.Writer) *InterceptedOutputStream { + return &InterceptedOutputStream{Index: index, Stream: stream, Closed: make(chan bool, 1)} +} + +type InterceptedInputStream struct { + Index string + Stream io.Reader + Error error + Closed chan bool +} + +func NewInterceptedInputStream(index string, stream io.Reader) *InterceptedInputStream { + return &InterceptedInputStream{Index: index, Stream: stream, Closed: make(chan bool, 1)} +} diff --git a/server.go b/server.go index 69221a1..0a018cd 100644 --- a/server.go +++ b/server.go @@ -2,10 +2,12 @@ package guac import ( "fmt" - logger "github.com/sirupsen/logrus" "io" "net/http" "strings" + + "github.com/gorilla/mux" + logger "github.com/sirupsen/logrus" ) const ( @@ -60,21 +62,27 @@ func (s *Server) sendError(response http.ResponseWriter, guacStatus Status, mess } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - err := s.handleTunnelRequestCore(w, r) - if err == nil { - return - } - guacErr := err.(*ErrGuac) - switch guacErr.Kind { - case ErrClient: - logger.Warn("HTTP tunnel request rejected: ", err.Error()) - s.sendError(w, guacErr.Status, err.Error()) - default: - logger.Error("HTTP tunnel request failed: ", err.Error()) - logger.Debug("Internal error in HTTP tunnel.", err) - s.sendError(w, guacErr.Status, "Internal server error.") - } - return + m := mux.NewRouter() + + m.HandleFunc("/debug", func(rw http.ResponseWriter, r *http.Request) { + w.Write([]byte("CIAOOOOOOO")) + }) + + m.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + err := s.handleTunnelRequestCore(w, r) + guacErr := err.(*ErrGuac) + switch guacErr.Kind { + case ErrClient: + logger.Warn("HTTP tunnel request rejected: ", err.Error()) + s.sendError(w, guacErr.Status, err.Error()) + default: + logger.Error("HTTP tunnel request failed: ", err.Error()) + logger.Debug("Internal error in HTTP tunnel.", err) + s.sendError(w, guacErr.Status, "Internal server error.") + } + }) + + m.ServeHTTP(w, r) } func (s *Server) handleTunnelRequestCore(response http.ResponseWriter, request *http.Request) (err error) { diff --git a/stream.go b/stream.go index 3f31622..e8ea12c 100644 --- a/stream.go +++ b/stream.go @@ -2,14 +2,17 @@ package guac import ( "fmt" + "io" "net" "time" "github.com/sirupsen/logrus" ) +var _ io.Writer = (*Stream)(nil) + const ( - SocketTimeout = 15 * time.Second + SocketTimeout = 60 * time.Second MaxGuacMessage = 8192 // TODO is this bytes or runes? ) diff --git a/stream_intercepting_filter.go b/stream_intercepting_filter.go new file mode 100644 index 0000000..ce0ac56 --- /dev/null +++ b/stream_intercepting_filter.go @@ -0,0 +1,314 @@ +package guac + +import ( + "encoding/base64" + "errors" + "io" + "strconv" + "sync" + + "github.com/sirupsen/logrus" +) + +var ( + _ Filter = (*InputStreamInterceptingFilter)(nil) + _ Filter = (*OutputStreamInterceptingFilter)(nil) +) + +// Whether this OutputStreamInterceptingFilter should respond to received +// blobs with "ack" messages on behalf of the client. If false, blobs will +// still be handled by this filter, but empty blobs will be sent to the +// client, forcing the client to respond on its own. +var acknowledgeBlobs bool = true + +type InputStreamInterceptingFilter struct { + tunnel Tunnel + istreamLock sync.Mutex + + streams map[string]*InterceptedInputStream +} + +func NewInputStreamInterceptingFilter(tunnel Tunnel) *InputStreamInterceptingFilter { + streams := make(map[string]*InterceptedInputStream) + return &InputStreamInterceptingFilter{tunnel: tunnel, streams: streams} +} + +func (t *InputStreamInterceptingFilter) sendInstruction(instr *Instruction) (err error) { + w := t.tunnel.AcquireWriter() + defer t.tunnel.ReleaseWriter() + + if _, err = w.Write(instr.Byte()); err != nil { + return err + } + + return nil +} + +func (t *InputStreamInterceptingFilter) getInterceptedInputStream(index string) *InterceptedInputStream { + return t.streams[index] +} + +func (t *InputStreamInterceptingFilter) closeInterceptedStream(index string) { + t.istreamLock.Lock() + if t.streams[index] != nil { + t.streams[index].Closed <- true + } + delete(t.streams, index) + t.istreamLock.Unlock() +} + +func (t *InputStreamInterceptingFilter) CloseAll() { + for k := range t.streams { + t.closeInterceptedStream(k) + } +} + +func (t *InputStreamInterceptingFilter) InterceptStream(index int, stream io.Reader) error { + indexStr := strconv.Itoa(index) + + interceptedInputStream := NewInterceptedInputStream(indexStr, stream) + + logrus.Debug("intercepting input stream", indexStr) + + t.istreamLock.Lock() + t.streams[indexStr] = interceptedInputStream + t.istreamLock.Unlock() + + t.handleInterceptedStream(interceptedInputStream) + + <-interceptedInputStream.Closed + + return interceptedInputStream.Error +} + +func (t *InputStreamInterceptingFilter) sendBlob(index string, blob []byte) { + data := base64.StdEncoding.Strict().EncodeToString(blob) + if err := t.sendInstruction(NewInstruction("blob", index, data)); err != nil { + logrus.Errorf("failed to send base64 blob to stream index %s %v", index, err) + } +} + +func (t *InputStreamInterceptingFilter) sendEnd(index string) { + if err := t.sendInstruction(NewInstruction("end", index)); err != nil { + logrus.Errorf("failed to send end to stream index %s %v", index, err) + } +} + +func (t *InputStreamInterceptingFilter) readNextBlob(stream *InterceptedInputStream) { + blob := make([]byte, 4096) + + if n, err := stream.Stream.Read(blob); err != nil { + if n > 0 { + t.sendBlob(stream.Index, blob[:n]) + return + } + logrus.Errorf("could not read from stream %s: %v", stream.Index, err) + t.sendEnd(stream.Index) + t.closeInterceptedStream(stream.Index) + + return + } + + t.sendBlob(stream.Index, blob) +} + +func (t *InputStreamInterceptingFilter) handleACK(instruction *Instruction) { + if len(instruction.Args) < 3 { + return + } + + index := instruction.Args[0] + + stream := t.getInterceptedInputStream(index) + if stream == nil { + logrus.Warning("empty intercepted input stream on ACK") + return + } + + status := instruction.Args[2] + code := Success + + if status != "0" { + codeInt, err := strconv.Atoi(status) + code = FromGuacamoleStatusCode(codeInt) + + if err != nil { + logrus.Error("failed to translate status code") + code = ServerError + } + + stream.Error = ErrServer.NewError(code.String(), instruction.Args[1]) + t.closeInterceptedStream(stream.Index) + return + } + + t.readNextBlob(stream) +} + +func (t *InputStreamInterceptingFilter) Filter(instruction *Instruction) (*Instruction, error) { + if instruction.Opcode == "ack" { + t.handleACK(instruction) + } + + return instruction, nil +} + +func (t *InputStreamInterceptingFilter) handleInterceptedStream(stream *InterceptedInputStream) { + t.readNextBlob(stream) +} + +type OutputStreamInterceptingFilter struct { + istreamLock sync.Mutex + tunnel Tunnel + streams map[string]*InterceptedOutputStream +} + +func NewOutputStreamInterceptingFilter(tunnel Tunnel) *OutputStreamInterceptingFilter { + streams := make(map[string]*InterceptedOutputStream) + return &OutputStreamInterceptingFilter{tunnel: tunnel, streams: streams} +} + +func (t *OutputStreamInterceptingFilter) sendInstruction(instr *Instruction) error { + w := t.tunnel.AcquireWriter() + if _, err := w.Write(instr.Byte()); err != nil { + return err + } + + t.tunnel.ReleaseWriter() + return nil +} + +func (t *OutputStreamInterceptingFilter) getInterceptedStream(idx string) *InterceptedOutputStream { + return t.streams[idx] +} + +func (t *OutputStreamInterceptingFilter) sendACK(index string, message string, status Status) { + if status != Success { + t.closeInterceptedStream(index) + } + + if err := t.sendInstruction(NewInstruction("ack", index, message, status.String())); err != nil { + logrus.Errorf("unable to send ACK for stream %s", index) + } +} + +func (t *OutputStreamInterceptingFilter) InterceptStream(index int, outStream io.Writer) error { + idxStr := strconv.Itoa(index) + if t.tunnel == nil { + return errors.New("invalid tunnel, it's nil") + } + + interceptedOutputStream := NewInterceptedOutputStream(idxStr, outStream) + + logrus.Debug(idxStr, "is now intercepted by outStream", outStream) + + t.istreamLock.Lock() + t.streams[idxStr] = interceptedOutputStream + t.istreamLock.Unlock() + + t.handleInterceptedStream(interceptedOutputStream) + + <-interceptedOutputStream.Closed + + return interceptedOutputStream.Error +} + +func (t *OutputStreamInterceptingFilter) handleBlob(instruction *Instruction) (*Instruction, error) { + // Verify all required arguments are present + args := instruction.Args + if len(args) < 2 { + return instruction, nil + } + + // Pull associated stream + streamIndex := args[0] + + outputInterceptedStream := t.getInterceptedStream(streamIndex) + if outputInterceptedStream == nil { + return instruction, nil + } + + // Decode blob + data := args[1] + + blob, err := base64.StdEncoding.DecodeString(data) + if err != nil { + return nil, err + } + + if outputInterceptedStream.Stream == nil { + logrus.Error("stream in outputInterceptedStream is nil") + return nil, errors.New("stream in outputInterceptedStream is nil") + } + + if _, err := outputInterceptedStream.Stream.Write(blob); err != nil { + logrus.WithError(err).Error("failed to write to intercepted stream") + return nil, err + } + + // Force client to respond with their own "ack" if we need to + // confirm that they are not falling behind with respect to the + // graphical session + if !acknowledgeBlobs { + acknowledgeBlobs = true + return NewInstruction("blob", streamIndex, ""), nil + } + + t.sendACK(streamIndex, "OK", Success) + + // Instruction was handled purely internally + return nil, nil +} + +func (t *OutputStreamInterceptingFilter) handleEnd(instruction *Instruction) { + args := instruction.Args + if len(args) < 1 { + return + } + + t.closeInterceptedStream(args[0]) +} + +func (t *OutputStreamInterceptingFilter) handleSync(instruction *Instruction) { + acknowledgeBlobs = false +} + +func (t *OutputStreamInterceptingFilter) Filter(instruction *Instruction) (*Instruction, error) { + switch instruction.Opcode { + case "blob": + // When a user cancels the download, the connection abruptly drops + // TODO: find a better design + return t.handleBlob(instruction) + case "end": + t.handleEnd(instruction) + return instruction, nil + case "sync": + t.handleSync(instruction) + return instruction, nil + default: + return instruction, nil + } +} + +func (t *OutputStreamInterceptingFilter) handleInterceptedStream(stream *InterceptedOutputStream) { + t.sendACK(stream.Index, "OK", Success) +} + +func (t *OutputStreamInterceptingFilter) closeInterceptedStream(index string) *InterceptedOutputStream { + interceptedStream := t.streams[index] + if interceptedStream != nil { + interceptedStream.Closed <- true + } + + t.istreamLock.Lock() + delete(t.streams, index) + t.istreamLock.Unlock() + + return interceptedStream +} + +func (t *OutputStreamInterceptingFilter) CloseAllInterceptedStreams() { + for k := range t.streams { + t.closeInterceptedStream(k) + } +} diff --git a/stream_intercepting_tunnel.go b/stream_intercepting_tunnel.go new file mode 100644 index 0000000..c12612f --- /dev/null +++ b/stream_intercepting_tunnel.go @@ -0,0 +1,88 @@ +package guac + +import ( + "io" + + "github.com/sirupsen/logrus" +) + +var _ Tunnel = (*StreamInterceptingTunnel)(nil) + +type StreamInterceptingTunnel struct { + tunnel Tunnel + + outputStreamFilter *OutputStreamInterceptingFilter + inputStreamFilter *InputStreamInterceptingFilter +} + +func NewStreamInterceptingTunnel(tunnel Tunnel) *StreamInterceptingTunnel { + stream := &StreamInterceptingTunnel{tunnel: tunnel} + stream.outputStreamFilter = NewOutputStreamInterceptingFilter(stream) + stream.inputStreamFilter = NewInputStreamInterceptingFilter(stream) + return stream +} + +func (t *StreamInterceptingTunnel) AcquireReader() InstructionReader { + reader := t.tunnel.AcquireReader() + + reader = NewFilteredGuacamoleReader(reader, t.outputStreamFilter) + reader = NewFilteredGuacamoleReader(reader, t.inputStreamFilter) + + return reader +} + +func (t *StreamInterceptingTunnel) ReleaseReader() { + t.tunnel.ReleaseReader() +} + +func (t *StreamInterceptingTunnel) HasQueuedReaderThreads() bool { + return t.tunnel.HasQueuedReaderThreads() +} + +func (t *StreamInterceptingTunnel) AcquireWriter() io.Writer { + return t.tunnel.AcquireWriter() +} + +func (t *StreamInterceptingTunnel) ReleaseWriter() { + t.tunnel.ReleaseWriter() +} + +func (t *StreamInterceptingTunnel) HasQueuedWriterThreads() bool { + return t.tunnel.HasQueuedWriterThreads() +} + +func (t *StreamInterceptingTunnel) GetUUID() string { + return t.tunnel.GetUUID() +} + +func (t *StreamInterceptingTunnel) Close() error { + t.outputStreamFilter.CloseAllInterceptedStreams() + + return t.tunnel.Close() +} + +func (t *StreamInterceptingTunnel) ConnectionID() string { + return t.tunnel.ConnectionID() +} + +func (t *StreamInterceptingTunnel) InterceptOutputStream(idx int, output io.Writer) error { + logrus.Debugf("Intercepting output stream %d of tunnel %s", idx, t.tunnel.ConnectionID()) + + if err := t.outputStreamFilter.InterceptStream(idx, output); err != nil { + return err + } + + logrus.Debugf("Finished intercepting output stream %d of tunnel %s", idx, t.ConnectionID()) + return nil +} + +func (t *StreamInterceptingTunnel) InterceptInputStream(idx int, input io.Reader) error { + logrus.Debugf("Intercepting input stream %d of tunnel %s", idx, t.ConnectionID()) + + if err := t.inputStreamFilter.InterceptStream(idx, input); err != nil { + return err + } + + logrus.Debugf("Finished intercepting input stream %d of tunnel %s", idx, t.ConnectionID()) + return nil +} diff --git a/tunnel.go b/tunnel.go index 5b37ca5..0514ad5 100644 --- a/tunnel.go +++ b/tunnel.go @@ -2,10 +2,17 @@ package guac import ( "fmt" - "github.com/google/uuid" "io" + + "github.com/google/uuid" ) +// Ensure SimpleTunnel implements the Tunnel interface +var _ Tunnel = (*SimpleTunnel)(nil) + +// // Ensure InstructionReader implements the InstructionReader interface +var _ InstructionReader = (*FilteredGuacamoleReader)(nil) + // The Guacamole protocol instruction Opcode reserved for arbitrary // internal use by tunnel implementations. The value of this Opcode is // guaranteed to be the empty string (""). Tunnel implementations may use @@ -116,3 +123,49 @@ func (t *SimpleTunnel) Close() (err error) { func (t *SimpleTunnel) GetUUID() string { return t.uuid.String() } + +type FilteredGuacamoleReader struct { + reader InstructionReader + filter Filter +} + +func NewFilteredGuacamoleReader(reader InstructionReader, filter Filter) *FilteredGuacamoleReader { + return &FilteredGuacamoleReader{reader: reader, filter: filter} +} + +func (r *FilteredGuacamoleReader) Available() bool { + return r.reader.Available() +} + +func (r *FilteredGuacamoleReader) Flush() { + r.reader.Flush() +} + +// ReadOne takes an instruction from the stream and parses it into an Instruction +func (r *FilteredGuacamoleReader) ReadOne() (instruction *Instruction, err error) { + instructionBuffer, err := r.reader.ReadSome() + if err != nil { + return + } + + return Parse(instructionBuffer) +} + +func (r *FilteredGuacamoleReader) ReadSome() ([]byte, error) { + for { + unfilteredInstruction, err := r.ReadOne() + if err != nil { + return nil, err + } + + filteredInstruction, err := r.filter.Filter(unfilteredInstruction) + if err != nil { + return nil, err + } + + // Continue reading and filtering until no instructions are dropped + if filteredInstruction != nil { + return filteredInstruction.Byte(), err + } + } +} diff --git a/ws_server.go b/ws_server.go index 2321a9d..c96f5f0 100644 --- a/ws_server.go +++ b/ws_server.go @@ -2,7 +2,6 @@ package guac import ( "bytes" - "io" "net/http" "github.com/gorilla/websocket" @@ -22,7 +21,7 @@ type WebsocketServer struct { OnDisconnect func(string, *http.Request, Tunnel) // OnConnectWs is an optional callback called when a websocket connects. - OnConnectWs func(string, *websocket.Conn, *http.Request) + OnConnectWs func(string, *websocket.Conn, *http.Request, Tunnel) // OnDisconnectWs is an optional callback called when the websocket disconnects. OnDisconnectWs func(string, *websocket.Conn, *http.Request, Tunnel) } @@ -92,12 +91,9 @@ func (s *WebsocketServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.OnConnect(id, r) } if s.OnConnectWs != nil { - s.OnConnectWs(id, ws, r) + s.OnConnectWs(id, ws, r, tunnel) } - writer := tunnel.AcquireWriter() - reader := tunnel.AcquireReader() - if s.OnDisconnect != nil { defer s.OnDisconnect(id, r, tunnel) } @@ -105,11 +101,8 @@ func (s *WebsocketServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer s.OnDisconnectWs(id, ws, r, tunnel) } - defer tunnel.ReleaseWriter() - defer tunnel.ReleaseReader() - - go wsToGuacd(ws, writer) - guacdToWs(ws, reader) + go wsToGuacd(ws, tunnel) + guacdToWs(ws, tunnel) } // MessageReader wraps a websocket connection and only permits Reading @@ -118,7 +111,7 @@ type MessageReader interface { ReadMessage() (int, []byte, error) } -func wsToGuacd(ws MessageReader, guacd io.Writer) { +func wsToGuacd(ws MessageReader, tunnel Tunnel) { for { _, data, err := ws.ReadMessage() if err != nil { @@ -130,11 +123,14 @@ func wsToGuacd(ws MessageReader, guacd io.Writer) { // messages starting with the InternalDataOpcode are never sent to guacd continue } - + guacd := tunnel.AcquireWriter() if _, err = guacd.Write(data); err != nil { logrus.Traceln("Failed writing to guacd", err) + tunnel.ReleaseWriter() + return } + tunnel.ReleaseWriter() } } @@ -144,11 +140,20 @@ type MessageWriter interface { WriteMessage(int, []byte) error } -func guacdToWs(ws MessageWriter, guacd InstructionReader) { +func guacdToWs(ws MessageWriter, tunnel Tunnel) { buf := bytes.NewBuffer(make([]byte, 0, MaxGuacMessage*2)) + uuid := NewInstruction(InternalDataOpcode, tunnel.ConnectionID()) + if err := ws.WriteMessage(1, uuid.Byte()); err != nil { + logrus.Traceln("Failed to send uuid to ws", err) + return + } + for { + guacd := tunnel.AcquireReader() ins, err := guacd.ReadSome() + tunnel.ReleaseReader() + if err != nil { logrus.Traceln("Error reading from guacd", err) return @@ -164,8 +169,12 @@ func guacdToWs(ws MessageWriter, guacd InstructionReader) { return } + guacd = tunnel.AcquireReader() + avail := guacd.Available() + tunnel.ReleaseReader() + // if the buffer has more data in it or we've reached the max buffer size, send the data and reset - if !guacd.Available() || buf.Len() >= MaxGuacMessage { + if !avail || buf.Len() >= MaxGuacMessage { if err = ws.WriteMessage(1, buf.Bytes()); err != nil { if err == websocket.ErrCloseSent { return diff --git a/ws_server_test.go b/ws_server_test.go index f51e067..1ef2b90 100644 --- a/ws_server_test.go +++ b/ws_server_test.go @@ -23,7 +23,7 @@ func TestWebsocketServer_guacdToWs(t *testing.T) { conn := &fakeConn{ ToRead: expected, } - guac := NewStream(conn, time.Minute) + guac := NewSimpleTunnel(NewStream(conn, time.Minute)) guacdToWs(msgWriter, guac) From d9c901656abc7f09d14a847e76b6343b72e85710 Mon Sep 17 00:00:00 2001 From: Manuel Romei Date: Thu, 17 Feb 2022 17:40:20 +0100 Subject: [PATCH 2/3] refactor: clean a lot, first draft of a test --- cmd/guac/guac.go | 59 +++--- config.go | 14 +- doc.go | 2 +- input_intercepting_filter.go | 168 ++++++++++++++++ input_intercepting_filter_test.go | 41 ++++ instruction.go | 41 ++++ instruction_test.go | 40 ++++ intercepted_stream.go | 16 +- output_intercepting_filter.go | 171 ++++++++++++++++ stream.go | 7 +- stream_intercepting_filter.go | 314 ------------------------------ stream_intercepting_tunnel.go | 88 --------- stream_test.go | 11 +- tunnel.go | 16 +- tunnel_map.go | 2 +- user_tunnel.go | 47 +++++ 16 files changed, 568 insertions(+), 469 deletions(-) create mode 100644 input_intercepting_filter.go create mode 100644 input_intercepting_filter_test.go create mode 100644 output_intercepting_filter.go delete mode 100644 stream_intercepting_filter.go delete mode 100644 stream_intercepting_tunnel.go create mode 100644 user_tunnel.go diff --git a/cmd/guac/guac.go b/cmd/guac/guac.go index d392982..4065cbd 100644 --- a/cmd/guac/guac.go +++ b/cmd/guac/guac.go @@ -22,24 +22,27 @@ func main() { // servlet := guac.NewServer(DemoDoConnect) wsServer := guac.NewWebsocketServer(DemoDoConnect) + wsServerIntercept := guac.NewWebsocketServer(DemoDoConnectWithIntercept) sessions := guac.NewMemorySessionStore() - wsServer.OnConnect = sessions.Add - wsServer.OnDisconnect = sessions.Delete + wsServerIntercept.OnConnect = sessions.Add + wsServerIntercept.OnDisconnect = sessions.Delete - tunnels = make(map[string]guac.Tunnel, 0) + tunnels = make(map[string]guac.Tunnel) - wsServer.OnConnectWs = func(s string, _ *websocket.Conn, _ *http.Request, t guac.Tunnel) { + wsServerIntercept.OnConnectWs = func(s string, _ *websocket.Conn, _ *http.Request, t guac.Tunnel) { tunnels[s] = t } - wsServer.OnDisconnectWs = func(s string, _ *websocket.Conn, _ *http.Request, _ guac.Tunnel) { + wsServerIntercept.OnDisconnectWs = func(s string, _ *websocket.Conn, _ *http.Request, _ guac.Tunnel) { delete(tunnels, s) } m := mux.NewRouter() + // m.Handle("/", servlet) m.Handle("/websocket-tunnel", wsServer) + m.Handle("/websocket-tunnel-intercept", wsServerIntercept) m.HandleFunc("/api/session/tunnels/{tunnel}/streams/{stream}/{file}", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Disposition", "attachment") @@ -52,7 +55,7 @@ func main() { return } - sit, ok := tunnel.(*guac.StreamInterceptingTunnel) + sit, ok := tunnel.(*guac.UserTunnel) if !ok { w.Write([]byte("Not supported")) w.WriteHeader(http.StatusBadRequest) @@ -61,20 +64,12 @@ func main() { stream := mux.Vars(r)["stream"] - streamIndex, err := strconv.Atoi(stream) - if err != nil { - w.Write([]byte("KO integer")) - w.WriteHeader(http.StatusBadRequest) - return - } - - if err := sit.InterceptOutputStream(streamIndex, w); err != nil { + if err := sit.InterceptOutputStream(stream, w); err != nil { w.Write([]byte("KO Intercepting output stream")) } }).Methods("GET") m.HandleFunc("/api/session/tunnels/{tunnel}/streams/{stream}/{file}", func(w http.ResponseWriter, r *http.Request) { - // w.Header().Set("Content-Type", "application/json") t := mux.Vars(r)["tunnel"] tunnel, ok := tunnels[t] if !ok { @@ -83,7 +78,7 @@ func main() { return } - sit, ok := tunnel.(*guac.StreamInterceptingTunnel) + sit, ok := tunnel.(*guac.UserTunnel) if !ok { w.Write([]byte("Not supported")) w.WriteHeader(http.StatusBadRequest) @@ -92,14 +87,7 @@ func main() { stream := mux.Vars(r)["stream"] - streamIndex, err := strconv.Atoi(stream) - if err != nil { - w.Write([]byte("KO integer")) - w.WriteHeader(http.StatusBadRequest) - return - } - - if err := sit.InterceptInputStream(streamIndex, r.Body); err != nil { + if err := sit.InterceptInputStream(stream, r.Body); err != nil { w.Write([]byte("KO intercepting input stream")) } }).Methods("POST") @@ -133,7 +121,7 @@ func main() { } // DemoDoConnect creates the tunnel to the remote machine (via guacd) -func DemoDoConnect(request *http.Request) (guac.Tunnel, error) { +func DemoDoConnect(request *http.Request) (_ guac.Tunnel, err error) { config := guac.NewGuacamoleConfiguration() var query url.Values @@ -157,16 +145,10 @@ func DemoDoConnect(request *http.Request) (guac.Tunnel, error) { } config.Protocol = query.Get("scheme") - config.Parameters = map[string]string{ - "enable-sftp": "true", - "sftp-hostname": "198.18.251.1", - "sftp-port": "22", - } for k, v := range query { config.Parameters[k] = v[0] } - var err error if query.Get("width") != "" { config.OptimalScreenHeight, err = strconv.Atoi(query.Get("width")) if err != nil || config.OptimalScreenHeight == 0 { @@ -184,7 +166,7 @@ func DemoDoConnect(request *http.Request) (guac.Tunnel, error) { config.AudioMimetypes = []string{"audio/L16", "rate=44100", "channels=2"} logrus.Debug("Connecting to guacd") - addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:4444") + addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:4822") if err != nil { logrus.Errorln("error while resolving 127.0.0.1") return nil, err @@ -202,6 +184,7 @@ func DemoDoConnect(request *http.Request) (guac.Tunnel, error) { if request.URL.Query().Get("uuid") != "" { config.ConnectionID = request.URL.Query().Get("uuid") } + logrus.Debugf("Starting handshake with %#v", config) err = stream.Handshake(config) if err != nil { @@ -209,5 +192,15 @@ func DemoDoConnect(request *http.Request) (guac.Tunnel, error) { } logrus.Debug("Socket configured") - return guac.NewStreamInterceptingTunnel(guac.NewSimpleTunnel(stream)), nil + return guac.NewSimpleTunnel(stream), nil +} + +// DemoDoConnectWithIntercept showcases a use for intercepting streams +func DemoDoConnectWithIntercept(r *http.Request) (guac.Tunnel, error) { + t, err := DemoDoConnect(r) + if err != nil { + return nil, err + } + + return guac.NewUserTunnel(t), nil } diff --git a/config.go b/config.go index c8af4fd..85f0ac7 100644 --- a/config.go +++ b/config.go @@ -5,22 +5,22 @@ type Config struct { // ConnectionID is used to reconnect to an existing session, otherwise leave blank for a new session. ConnectionID string // Protocol is the protocol of the connection from guacd to the remote (rdp, ssh, etc). - Protocol string + Protocol string // Parameters are used to configure protocol specific options like sla for rdp or terminal color schemes. - Parameters map[string]string + Parameters map[string]string // OptimalScreenWidth is the desired width of the screen - OptimalScreenWidth int + OptimalScreenWidth int // OptimalScreenHeight is the desired height of the screen OptimalScreenHeight int // OptimalResolution is the desired resolution of the screen - OptimalResolution int + OptimalResolution int // AudioMimetypes is an array of the supported audio types - AudioMimetypes []string + AudioMimetypes []string // VideoMimetypes is an array of the supported video types - VideoMimetypes []string + VideoMimetypes []string // ImageMimetypes is an array of the supported image types - ImageMimetypes []string + ImageMimetypes []string } // NewGuacamoleConfiguration returns a Config with sane defaults diff --git a/doc.go b/doc.go index f0dfdc0..4b89bff 100644 --- a/doc.go +++ b/doc.go @@ -1,4 +1,4 @@ /* Package guac implements a HTTP client and a WebSocket client that connects to an Apache Guacamole server. - */ +*/ package guac diff --git a/input_intercepting_filter.go b/input_intercepting_filter.go new file mode 100644 index 0000000..c7cbc01 --- /dev/null +++ b/input_intercepting_filter.go @@ -0,0 +1,168 @@ +package guac + +import ( + "encoding/base64" + "errors" + "io" + "strconv" + "sync" + + "github.com/sirupsen/logrus" +) + +var ( + _ Filter = (*InputInterceptingFilter)(nil) + _ Filter = (*OutputInterceptingFilter)(nil) +) + +// Whether this OutputInterceptingFilter should respond to received +// blobs with "ack" messages on behalf of the client. If false, blobs will +// still be handled by this filter, but empty blobs will be sent to the +// client, forcing the client to respond on its own. +var acknowledgeBlobs bool = true + +type InputInterceptingFilter struct { + tunnel Tunnel + l sync.Mutex + + streams map[string]*InterceptedInputStream +} + +func NewInputInterceptingFilter(tunnel Tunnel) *InputInterceptingFilter { + streams := make(map[string]*InterceptedInputStream) + return &InputInterceptingFilter{tunnel: tunnel, streams: streams} +} + +func (t *InputInterceptingFilter) sendInstruction(instr *Instruction) (err error) { + w := t.tunnel.AcquireWriter() + defer t.tunnel.ReleaseWriter() + + if _, err = w.Write(instr.Byte()); err != nil { + logrus.WithError(err).Error("failed to write instruction") + return err + } + + return nil +} + +func (t *InputInterceptingFilter) getInterceptedInputStream(index string) *InterceptedInputStream { + t.l.Lock() + defer t.l.Unlock() + + return t.streams[index] +} + +func (t *InputInterceptingFilter) closeInterceptedStream(index string, err error) { + t.l.Lock() + defer t.l.Unlock() + + if t.streams[index] != nil { + t.streams[index].done <- err + } + delete(t.streams, index) +} + +func (t *InputInterceptingFilter) CloseAll() { + for k := range t.streams { + t.closeInterceptedStream(k, nil) + } +} + +func (t *InputInterceptingFilter) InterceptStream(index string, stream io.Reader) <-chan error { + signal := make(chan error, 1) + + interceptedInputStream := NewInterceptedInputStream(index, stream, signal) + + t.l.Lock() + t.streams[index] = interceptedInputStream + t.l.Unlock() + + t.handleInterceptedStream(interceptedInputStream) + + return signal +} + +func (t *InputInterceptingFilter) sendBlob(index string, blob []byte) { + data := base64.StdEncoding.Strict().EncodeToString(blob) + if err := t.sendInstruction(NewInstruction("blob", index, data)); err != nil { + logrus.Errorf("failed to send base64 blob to stream index %s %v", index, err) + + t.sendEnd(index) + t.closeInterceptedStream(index, err) + } +} + +func (t *InputInterceptingFilter) sendEnd(index string) { + if err := t.sendInstruction(NewInstruction("end", index)); err != nil { + logrus.Errorf("failed to send end to stream index %s %v", index, err) + } +} + +func (t *InputInterceptingFilter) readNextBlob(stream *InterceptedInputStream) { + blob := make([]byte, 4096) + + if n, err := io.ReadFull(stream.Stream, blob); err != nil { + if n > 0 { + logrus.Debug("there are still some bytes") + t.sendBlob(stream.Index, blob[:n]) + return + } + + if !errors.Is(err, io.EOF) { + logrus.WithError(err).Errorf("could not read from stream %s", stream.Index) + } else { + err = nil + } + + t.sendEnd(stream.Index) + t.closeInterceptedStream(stream.Index, err) + + return + } + + t.sendBlob(stream.Index, blob) +} + +func (t *InputInterceptingFilter) handleACK(instruction *Instruction) { + if len(instruction.Args) < 3 { + return + } + + index := instruction.Args[0] + + stream := t.getInterceptedInputStream(index) + if stream == nil { + logrus.Warning("empty intercepted input stream on ACK") + return + } + + status := instruction.Args[2] + code := Success + + if status != "0" { + codeInt, err := strconv.Atoi(status) + code = FromGuacamoleStatusCode(codeInt) + + if err != nil { + logrus.Error("failed to translate status code") + code = ServerError + } + + t.closeInterceptedStream(stream.Index, ErrServer.NewError(code.String(), instruction.Args[1])) + return + } + + t.readNextBlob(stream) +} + +func (t *InputInterceptingFilter) Filter(instruction *Instruction) (*Instruction, error) { + if instruction.Opcode == "ack" { + t.handleACK(instruction) + } + + return instruction, nil +} + +func (t *InputInterceptingFilter) handleInterceptedStream(stream *InterceptedInputStream) { + t.readNextBlob(stream) +} diff --git a/input_intercepting_filter_test.go b/input_intercepting_filter_test.go new file mode 100644 index 0000000..75a0dd2 --- /dev/null +++ b/input_intercepting_filter_test.go @@ -0,0 +1,41 @@ +package guac + +import ( + "bytes" + "testing" + "time" +) + +func TestInputInterceptingFilter(t *testing.T) { + t.Run("OK", func(t *testing.T) { + conn := &fakeConn{ + ToRead: []byte(""), + } + + f := NewInputInterceptingFilter(NewUserTunnel( + NewSimpleTunnel( + NewStream(conn, time.Minute), + ), + )) + + toInject := []byte("Hello") + + // Hijack stream 1 and inject some data that will need to end up on the wire + finished := f.InterceptStream("1", bytes.NewReader([]byte(toInject))) + + // base64("Hello") = "SGVsbG8=" + if got, want := string(conn.ToWrite), "4.blob,1.1,8.SGVsbG8=;"; got != want { + t.Fatalf("On the wire: %v, want=%v", got, want) + } + + // Simulate an ACK from guacd + f.Filter(NewInstruction("ack", "1", "", "0")) + if err := <-finished; err != nil { + t.Fatal(err) + } + + if got, want := string(conn.ToWrite), "4.blob,1.1,8.SGVsbG8=;3.end,1.1;"; got != want { + t.Fatalf("On the wire: %v, want=%v", got, want) + } + }) +} diff --git a/instruction.go b/instruction.go index d8529dd..96a986e 100644 --- a/instruction.go +++ b/instruction.go @@ -111,3 +111,44 @@ func ReadOne(stream *Stream) (instruction *Instruction, err error) { return Parse(instructionBuffer) } + +// FilteredInstructionReader is a struct that provides a filtered +// InstructionReader and handles instructions through a filter. +type FilteredInstructionReader struct { + InstructionReader + + Filter +} + +func NewFilteredInstructionReader(r InstructionReader, filter Filter) InstructionReader { + return &FilteredInstructionReader{r, filter} +} + +func (r *FilteredInstructionReader) ReadSome() ([]byte, error) { + for { + unfilteredInstruction, err := readOne(r.InstructionReader) + if err != nil { + return nil, err + } + + filteredInstruction, err := r.Filter.Filter(unfilteredInstruction) + if err != nil { + return nil, err + } + + // Continue reading and filtering until no instructions are dropped + if filteredInstruction != nil { + return filteredInstruction.Byte(), err + } + } +} + +// readOne takes an instruction from the stream and parses it into an Instruction +func readOne(r InstructionReader) (instruction *Instruction, err error) { + instructionBuffer, err := r.ReadSome() + if err != nil { + return + } + + return Parse(instructionBuffer) +} diff --git a/instruction_test.go b/instruction_test.go index f8629c4..14a0951 100644 --- a/instruction_test.go +++ b/instruction_test.go @@ -68,3 +68,43 @@ func TestReadOne(t *testing.T) { t.Error("Unexpected", ins.String()) } } + +var _ Filter = (*dropFilter)(nil) + +// dropFilter drops all the instructions defined in drop +type dropFilter struct { + Drop []string +} + +func (f *dropFilter) Filter(i *Instruction) (*Instruction, error) { + for _, v := range f.Drop { + if v == i.Opcode { + return nil, nil + } + } + + return i, nil +} + +func TestFilteredInstructionReader(t *testing.T) { + t.Run("OK", func(t *testing.T) { + f := &dropFilter{Drop: []string{"select"}} + + s := NewStream(&fakeConn{ + ToRead: []byte(`6.select,2.hi,5.hello,4.asdf;6.teston,2.hi,5.hello,4.asdf;`), + }, time.Minute) + + fi := NewFilteredInstructionReader(s, f) + + result, err := fi.ReadSome() + if err != nil { + t.Fatal(err) + } + + if got, want := string(result), "6.teston,2.hi,5.hello,4.asdf;"; got != want { + t.Fatalf("Result=%v, want %v", got, want) + } + }) + + // Won't test malformed input, because that's already tested on Stream +} diff --git a/intercepted_stream.go b/intercepted_stream.go index 624b5eb..bb98001 100644 --- a/intercepted_stream.go +++ b/intercepted_stream.go @@ -5,21 +5,21 @@ import "io" type InterceptedOutputStream struct { Index string Stream io.Writer - Error error - Closed chan bool + + done chan<- error } -func NewInterceptedOutputStream(index string, stream io.Writer) *InterceptedOutputStream { - return &InterceptedOutputStream{Index: index, Stream: stream, Closed: make(chan bool, 1)} +func NewInterceptedOutputStream(index string, stream io.Writer, signal chan<- error) *InterceptedOutputStream { + return &InterceptedOutputStream{Index: index, Stream: stream, done: signal} } type InterceptedInputStream struct { Index string Stream io.Reader - Error error - Closed chan bool + + done chan<- error } -func NewInterceptedInputStream(index string, stream io.Reader) *InterceptedInputStream { - return &InterceptedInputStream{Index: index, Stream: stream, Closed: make(chan bool, 1)} +func NewInterceptedInputStream(index string, stream io.Reader, signal chan<- error) *InterceptedInputStream { + return &InterceptedInputStream{Index: index, Stream: stream, done: signal} } diff --git a/output_intercepting_filter.go b/output_intercepting_filter.go new file mode 100644 index 0000000..14b5c38 --- /dev/null +++ b/output_intercepting_filter.go @@ -0,0 +1,171 @@ +package guac + +import ( + "encoding/base64" + "errors" + "io" + "sync" + + "github.com/sirupsen/logrus" +) + +type OutputInterceptingFilter struct { + l sync.Mutex + tunnel Tunnel + streams map[string]*InterceptedOutputStream +} + +func NewOutputInterceptingFilter(tunnel Tunnel) *OutputInterceptingFilter { + streams := make(map[string]*InterceptedOutputStream) + return &OutputInterceptingFilter{tunnel: tunnel, streams: streams} +} + +func (t *OutputInterceptingFilter) sendInstruction(instr *Instruction) error { + w := t.tunnel.AcquireWriter() + if _, err := w.Write(instr.Byte()); err != nil { + logrus.WithError(err).Error("failed to send instruction") + return err + } + + t.tunnel.ReleaseWriter() + return nil +} + +func (t *OutputInterceptingFilter) getInterceptedStream(idx string) *InterceptedOutputStream { + t.l.Lock() + defer t.l.Unlock() + + return t.streams[idx] +} + +func (t *OutputInterceptingFilter) sendACK(index string, message string, status Status) { + if status != Success { + t.closeInterceptedStream(index, ErrServer.NewError(status.String(), message)) + } + + if err := t.sendInstruction(NewInstruction("ack", index, message, status.String())); err != nil { + logrus.Errorf("unable to send ACK for stream %s", index) + } +} + +func (t *OutputInterceptingFilter) InterceptStream(index string, outStream io.Writer) <-chan error { + signal := make(chan error, 1) + + if t.tunnel == nil { + defer func() { + signal <- errors.New("invalid tunnel") + }() + + return signal + } + + interceptedOutputStream := NewInterceptedOutputStream(index, outStream, signal) + + t.l.Lock() + t.streams[index] = interceptedOutputStream + t.l.Unlock() + + t.handleInterceptedStream(interceptedOutputStream) + + return signal +} + +func (t *OutputInterceptingFilter) handleBlob(instruction *Instruction) (*Instruction, error) { + // Verify all required arguments are present + args := instruction.Args + if len(args) < 2 { + return instruction, nil + } + + // Pull associated stream + streamIndex := args[0] + + outputInterceptedStream := t.getInterceptedStream(streamIndex) + if outputInterceptedStream == nil { + return instruction, nil + } + + // Decode blob + data := args[1] + + blob, err := base64.StdEncoding.DecodeString(data) + if err != nil { + return nil, err + } + + if outputInterceptedStream.Stream == nil { + return nil, errors.New("stream in outputInterceptedStream is nil") + } + + if _, err := outputInterceptedStream.Stream.Write(blob); err != nil { + // User closed the connection, no need to panic, + // Just don't track it anymore and close the stream. + t.closeInterceptedStream(streamIndex, nil) + + logrus.WithError(err).Info("failed to write to intercepted stream: maybe user has closed the connection?") + + // Exit cleanly, we don't need to make the server quit listening. + return nil, nil + } + + // Force client to respond with their own "ack" if we need to + // confirm that they are not falling behind with respect to the + // graphical session + if !acknowledgeBlobs { + acknowledgeBlobs = true + return NewInstruction("blob", streamIndex, ""), nil + } + + t.sendACK(streamIndex, "OK", Success) + + // Instruction was handled purely internally + return nil, nil +} + +func (t *OutputInterceptingFilter) handleEnd(instruction *Instruction) { + args := instruction.Args + if len(args) < 1 { + return + } + + t.closeInterceptedStream(args[0], nil) +} + +func (t *OutputInterceptingFilter) handleSync(instruction *Instruction) { + acknowledgeBlobs = false +} + +func (t *OutputInterceptingFilter) Filter(instruction *Instruction) (*Instruction, error) { + switch instruction.Opcode { + case "blob": + return t.handleBlob(instruction) + case "end": + t.handleEnd(instruction) + case "sync": + t.handleSync(instruction) + } + return instruction, nil +} + +func (t *OutputInterceptingFilter) handleInterceptedStream(stream *InterceptedOutputStream) { + t.sendACK(stream.Index, "OK", Success) +} + +func (t *OutputInterceptingFilter) closeInterceptedStream(index string, err error) *InterceptedOutputStream { + interceptedStream := t.streams[index] + if interceptedStream != nil { + interceptedStream.done <- err + } + + t.l.Lock() + delete(t.streams, index) + t.l.Unlock() + + return interceptedStream +} + +func (t *OutputInterceptingFilter) CloseAllInterceptedStreams() { + for k := range t.streams { + t.closeInterceptedStream(k, nil) + } +} diff --git a/stream.go b/stream.go index e8ea12c..592d1a9 100644 --- a/stream.go +++ b/stream.go @@ -2,20 +2,19 @@ package guac import ( "fmt" - "io" "net" "time" "github.com/sirupsen/logrus" ) -var _ io.Writer = (*Stream)(nil) - const ( - SocketTimeout = 60 * time.Second + SocketTimeout = 120 * time.Second MaxGuacMessage = 8192 // TODO is this bytes or runes? ) +var _ InstructionReader = (*Stream)(nil) + // Stream wraps the connection to Guacamole providing timeouts and reading // a single instruction at a time (since returning partial instructions // would be an error) diff --git a/stream_intercepting_filter.go b/stream_intercepting_filter.go deleted file mode 100644 index ce0ac56..0000000 --- a/stream_intercepting_filter.go +++ /dev/null @@ -1,314 +0,0 @@ -package guac - -import ( - "encoding/base64" - "errors" - "io" - "strconv" - "sync" - - "github.com/sirupsen/logrus" -) - -var ( - _ Filter = (*InputStreamInterceptingFilter)(nil) - _ Filter = (*OutputStreamInterceptingFilter)(nil) -) - -// Whether this OutputStreamInterceptingFilter should respond to received -// blobs with "ack" messages on behalf of the client. If false, blobs will -// still be handled by this filter, but empty blobs will be sent to the -// client, forcing the client to respond on its own. -var acknowledgeBlobs bool = true - -type InputStreamInterceptingFilter struct { - tunnel Tunnel - istreamLock sync.Mutex - - streams map[string]*InterceptedInputStream -} - -func NewInputStreamInterceptingFilter(tunnel Tunnel) *InputStreamInterceptingFilter { - streams := make(map[string]*InterceptedInputStream) - return &InputStreamInterceptingFilter{tunnel: tunnel, streams: streams} -} - -func (t *InputStreamInterceptingFilter) sendInstruction(instr *Instruction) (err error) { - w := t.tunnel.AcquireWriter() - defer t.tunnel.ReleaseWriter() - - if _, err = w.Write(instr.Byte()); err != nil { - return err - } - - return nil -} - -func (t *InputStreamInterceptingFilter) getInterceptedInputStream(index string) *InterceptedInputStream { - return t.streams[index] -} - -func (t *InputStreamInterceptingFilter) closeInterceptedStream(index string) { - t.istreamLock.Lock() - if t.streams[index] != nil { - t.streams[index].Closed <- true - } - delete(t.streams, index) - t.istreamLock.Unlock() -} - -func (t *InputStreamInterceptingFilter) CloseAll() { - for k := range t.streams { - t.closeInterceptedStream(k) - } -} - -func (t *InputStreamInterceptingFilter) InterceptStream(index int, stream io.Reader) error { - indexStr := strconv.Itoa(index) - - interceptedInputStream := NewInterceptedInputStream(indexStr, stream) - - logrus.Debug("intercepting input stream", indexStr) - - t.istreamLock.Lock() - t.streams[indexStr] = interceptedInputStream - t.istreamLock.Unlock() - - t.handleInterceptedStream(interceptedInputStream) - - <-interceptedInputStream.Closed - - return interceptedInputStream.Error -} - -func (t *InputStreamInterceptingFilter) sendBlob(index string, blob []byte) { - data := base64.StdEncoding.Strict().EncodeToString(blob) - if err := t.sendInstruction(NewInstruction("blob", index, data)); err != nil { - logrus.Errorf("failed to send base64 blob to stream index %s %v", index, err) - } -} - -func (t *InputStreamInterceptingFilter) sendEnd(index string) { - if err := t.sendInstruction(NewInstruction("end", index)); err != nil { - logrus.Errorf("failed to send end to stream index %s %v", index, err) - } -} - -func (t *InputStreamInterceptingFilter) readNextBlob(stream *InterceptedInputStream) { - blob := make([]byte, 4096) - - if n, err := stream.Stream.Read(blob); err != nil { - if n > 0 { - t.sendBlob(stream.Index, blob[:n]) - return - } - logrus.Errorf("could not read from stream %s: %v", stream.Index, err) - t.sendEnd(stream.Index) - t.closeInterceptedStream(stream.Index) - - return - } - - t.sendBlob(stream.Index, blob) -} - -func (t *InputStreamInterceptingFilter) handleACK(instruction *Instruction) { - if len(instruction.Args) < 3 { - return - } - - index := instruction.Args[0] - - stream := t.getInterceptedInputStream(index) - if stream == nil { - logrus.Warning("empty intercepted input stream on ACK") - return - } - - status := instruction.Args[2] - code := Success - - if status != "0" { - codeInt, err := strconv.Atoi(status) - code = FromGuacamoleStatusCode(codeInt) - - if err != nil { - logrus.Error("failed to translate status code") - code = ServerError - } - - stream.Error = ErrServer.NewError(code.String(), instruction.Args[1]) - t.closeInterceptedStream(stream.Index) - return - } - - t.readNextBlob(stream) -} - -func (t *InputStreamInterceptingFilter) Filter(instruction *Instruction) (*Instruction, error) { - if instruction.Opcode == "ack" { - t.handleACK(instruction) - } - - return instruction, nil -} - -func (t *InputStreamInterceptingFilter) handleInterceptedStream(stream *InterceptedInputStream) { - t.readNextBlob(stream) -} - -type OutputStreamInterceptingFilter struct { - istreamLock sync.Mutex - tunnel Tunnel - streams map[string]*InterceptedOutputStream -} - -func NewOutputStreamInterceptingFilter(tunnel Tunnel) *OutputStreamInterceptingFilter { - streams := make(map[string]*InterceptedOutputStream) - return &OutputStreamInterceptingFilter{tunnel: tunnel, streams: streams} -} - -func (t *OutputStreamInterceptingFilter) sendInstruction(instr *Instruction) error { - w := t.tunnel.AcquireWriter() - if _, err := w.Write(instr.Byte()); err != nil { - return err - } - - t.tunnel.ReleaseWriter() - return nil -} - -func (t *OutputStreamInterceptingFilter) getInterceptedStream(idx string) *InterceptedOutputStream { - return t.streams[idx] -} - -func (t *OutputStreamInterceptingFilter) sendACK(index string, message string, status Status) { - if status != Success { - t.closeInterceptedStream(index) - } - - if err := t.sendInstruction(NewInstruction("ack", index, message, status.String())); err != nil { - logrus.Errorf("unable to send ACK for stream %s", index) - } -} - -func (t *OutputStreamInterceptingFilter) InterceptStream(index int, outStream io.Writer) error { - idxStr := strconv.Itoa(index) - if t.tunnel == nil { - return errors.New("invalid tunnel, it's nil") - } - - interceptedOutputStream := NewInterceptedOutputStream(idxStr, outStream) - - logrus.Debug(idxStr, "is now intercepted by outStream", outStream) - - t.istreamLock.Lock() - t.streams[idxStr] = interceptedOutputStream - t.istreamLock.Unlock() - - t.handleInterceptedStream(interceptedOutputStream) - - <-interceptedOutputStream.Closed - - return interceptedOutputStream.Error -} - -func (t *OutputStreamInterceptingFilter) handleBlob(instruction *Instruction) (*Instruction, error) { - // Verify all required arguments are present - args := instruction.Args - if len(args) < 2 { - return instruction, nil - } - - // Pull associated stream - streamIndex := args[0] - - outputInterceptedStream := t.getInterceptedStream(streamIndex) - if outputInterceptedStream == nil { - return instruction, nil - } - - // Decode blob - data := args[1] - - blob, err := base64.StdEncoding.DecodeString(data) - if err != nil { - return nil, err - } - - if outputInterceptedStream.Stream == nil { - logrus.Error("stream in outputInterceptedStream is nil") - return nil, errors.New("stream in outputInterceptedStream is nil") - } - - if _, err := outputInterceptedStream.Stream.Write(blob); err != nil { - logrus.WithError(err).Error("failed to write to intercepted stream") - return nil, err - } - - // Force client to respond with their own "ack" if we need to - // confirm that they are not falling behind with respect to the - // graphical session - if !acknowledgeBlobs { - acknowledgeBlobs = true - return NewInstruction("blob", streamIndex, ""), nil - } - - t.sendACK(streamIndex, "OK", Success) - - // Instruction was handled purely internally - return nil, nil -} - -func (t *OutputStreamInterceptingFilter) handleEnd(instruction *Instruction) { - args := instruction.Args - if len(args) < 1 { - return - } - - t.closeInterceptedStream(args[0]) -} - -func (t *OutputStreamInterceptingFilter) handleSync(instruction *Instruction) { - acknowledgeBlobs = false -} - -func (t *OutputStreamInterceptingFilter) Filter(instruction *Instruction) (*Instruction, error) { - switch instruction.Opcode { - case "blob": - // When a user cancels the download, the connection abruptly drops - // TODO: find a better design - return t.handleBlob(instruction) - case "end": - t.handleEnd(instruction) - return instruction, nil - case "sync": - t.handleSync(instruction) - return instruction, nil - default: - return instruction, nil - } -} - -func (t *OutputStreamInterceptingFilter) handleInterceptedStream(stream *InterceptedOutputStream) { - t.sendACK(stream.Index, "OK", Success) -} - -func (t *OutputStreamInterceptingFilter) closeInterceptedStream(index string) *InterceptedOutputStream { - interceptedStream := t.streams[index] - if interceptedStream != nil { - interceptedStream.Closed <- true - } - - t.istreamLock.Lock() - delete(t.streams, index) - t.istreamLock.Unlock() - - return interceptedStream -} - -func (t *OutputStreamInterceptingFilter) CloseAllInterceptedStreams() { - for k := range t.streams { - t.closeInterceptedStream(k) - } -} diff --git a/stream_intercepting_tunnel.go b/stream_intercepting_tunnel.go deleted file mode 100644 index c12612f..0000000 --- a/stream_intercepting_tunnel.go +++ /dev/null @@ -1,88 +0,0 @@ -package guac - -import ( - "io" - - "github.com/sirupsen/logrus" -) - -var _ Tunnel = (*StreamInterceptingTunnel)(nil) - -type StreamInterceptingTunnel struct { - tunnel Tunnel - - outputStreamFilter *OutputStreamInterceptingFilter - inputStreamFilter *InputStreamInterceptingFilter -} - -func NewStreamInterceptingTunnel(tunnel Tunnel) *StreamInterceptingTunnel { - stream := &StreamInterceptingTunnel{tunnel: tunnel} - stream.outputStreamFilter = NewOutputStreamInterceptingFilter(stream) - stream.inputStreamFilter = NewInputStreamInterceptingFilter(stream) - return stream -} - -func (t *StreamInterceptingTunnel) AcquireReader() InstructionReader { - reader := t.tunnel.AcquireReader() - - reader = NewFilteredGuacamoleReader(reader, t.outputStreamFilter) - reader = NewFilteredGuacamoleReader(reader, t.inputStreamFilter) - - return reader -} - -func (t *StreamInterceptingTunnel) ReleaseReader() { - t.tunnel.ReleaseReader() -} - -func (t *StreamInterceptingTunnel) HasQueuedReaderThreads() bool { - return t.tunnel.HasQueuedReaderThreads() -} - -func (t *StreamInterceptingTunnel) AcquireWriter() io.Writer { - return t.tunnel.AcquireWriter() -} - -func (t *StreamInterceptingTunnel) ReleaseWriter() { - t.tunnel.ReleaseWriter() -} - -func (t *StreamInterceptingTunnel) HasQueuedWriterThreads() bool { - return t.tunnel.HasQueuedWriterThreads() -} - -func (t *StreamInterceptingTunnel) GetUUID() string { - return t.tunnel.GetUUID() -} - -func (t *StreamInterceptingTunnel) Close() error { - t.outputStreamFilter.CloseAllInterceptedStreams() - - return t.tunnel.Close() -} - -func (t *StreamInterceptingTunnel) ConnectionID() string { - return t.tunnel.ConnectionID() -} - -func (t *StreamInterceptingTunnel) InterceptOutputStream(idx int, output io.Writer) error { - logrus.Debugf("Intercepting output stream %d of tunnel %s", idx, t.tunnel.ConnectionID()) - - if err := t.outputStreamFilter.InterceptStream(idx, output); err != nil { - return err - } - - logrus.Debugf("Finished intercepting output stream %d of tunnel %s", idx, t.ConnectionID()) - return nil -} - -func (t *StreamInterceptingTunnel) InterceptInputStream(idx int, input io.Reader) error { - logrus.Debugf("Intercepting input stream %d of tunnel %s", idx, t.ConnectionID()) - - if err := t.inputStreamFilter.InterceptStream(idx, input); err != nil { - return err - } - - logrus.Debugf("Finished intercepting input stream %d of tunnel %s", idx, t.ConnectionID()) - return nil -} diff --git a/stream_test.go b/stream_test.go index f5f8bc5..8e73186 100644 --- a/stream_test.go +++ b/stream_test.go @@ -4,6 +4,7 @@ import ( "bytes" "io" "net" + "sync" "testing" "time" ) @@ -98,12 +99,16 @@ func TestInstructionReader_Flush(t *testing.T) { } type fakeConn struct { + lock sync.Mutex ToRead []byte + ToWrite []byte HasRead bool Closed bool } func (f *fakeConn) Read(b []byte) (n int, err error) { + f.lock.Lock() + defer f.lock.Unlock() if f.HasRead { return 0, io.EOF } else { @@ -113,7 +118,11 @@ func (f *fakeConn) Read(b []byte) (n int, err error) { } func (f *fakeConn) Write(b []byte) (n int, err error) { - return 0, nil + f.lock.Lock() + defer f.lock.Unlock() + + f.ToWrite = append(f.ToWrite, b...) + return len(b), nil } func (f *fakeConn) Close() error { diff --git a/tunnel.go b/tunnel.go index 0514ad5..4e16e31 100644 --- a/tunnel.go +++ b/tunnel.go @@ -10,7 +10,7 @@ import ( // Ensure SimpleTunnel implements the Tunnel interface var _ Tunnel = (*SimpleTunnel)(nil) -// // Ensure InstructionReader implements the InstructionReader interface +// Ensure InstructionReader implements the InstructionReader interface var _ InstructionReader = (*FilteredGuacamoleReader)(nil) // The Guacamole protocol instruction Opcode reserved for arbitrary @@ -125,25 +125,17 @@ func (t *SimpleTunnel) GetUUID() string { } type FilteredGuacamoleReader struct { - reader InstructionReader + InstructionReader filter Filter } func NewFilteredGuacamoleReader(reader InstructionReader, filter Filter) *FilteredGuacamoleReader { - return &FilteredGuacamoleReader{reader: reader, filter: filter} -} - -func (r *FilteredGuacamoleReader) Available() bool { - return r.reader.Available() -} - -func (r *FilteredGuacamoleReader) Flush() { - r.reader.Flush() + return &FilteredGuacamoleReader{reader, filter} } // ReadOne takes an instruction from the stream and parses it into an Instruction func (r *FilteredGuacamoleReader) ReadOne() (instruction *Instruction, err error) { - instructionBuffer, err := r.reader.ReadSome() + instructionBuffer, err := r.InstructionReader.ReadSome() if err != nil { return } diff --git a/tunnel_map.go b/tunnel_map.go index 306a5fd..97df076 100644 --- a/tunnel_map.go +++ b/tunnel_map.go @@ -62,7 +62,7 @@ type TunnelMap struct { tunnelTimeout time.Duration // Map of all tunnels that are using HTTP, indexed by tunnel UUID. - tunnelMap map[string]*LastAccessedTunnel + tunnelMap map[string]*LastAccessedTunnel } // NewTunnelMap creates a new TunnelMap and starts the scheduled job with the default timeout. diff --git a/user_tunnel.go b/user_tunnel.go new file mode 100644 index 0000000..47ab66a --- /dev/null +++ b/user_tunnel.go @@ -0,0 +1,47 @@ +package guac + +import "io" + +// Ensure UserTunnel implements Tunnel +var _ Tunnel = (*UserTunnel)(nil) + +type UserTunnel struct { + Tunnel + + outputFilter *OutputInterceptingFilter + inputFilter *InputInterceptingFilter +} + +func NewUserTunnel(tunnel Tunnel) *UserTunnel { + tun := &UserTunnel{Tunnel: tunnel} + + tun.inputFilter, tun.outputFilter = NewInputInterceptingFilter(tun), NewOutputInterceptingFilter(tun) + + return tun +} + +// InterceptOutputStream intercepts an output stream, i.e. when downloading +// a file you provide a http.ResponseWriter and InterceptOutputStream will +// pipe the stream numbers through it. +func (t *UserTunnel) InterceptOutputStream(id string, stream io.Writer) error { + return <-t.outputFilter.InterceptStream(id, stream) +} + +// InterceptInputStream intercepts an input stream, i.e. when uploading a file. +// For example you can pass a http.Request.Body() to inject a file in a Guacamole stream. +func (t *UserTunnel) InterceptInputStream(id string, stream io.Reader) error { + return <-t.inputFilter.InterceptStream(id, stream) +} + +// AcquireReader of UserTunnel wraps the original AcquireReader +// but it filters the instructions before handing them to the +// caller. +func (t *UserTunnel) AcquireReader() InstructionReader { + reader := t.Tunnel.AcquireReader() + + // Filter both for input and output streams + return NewFilteredInstructionReader( + NewFilteredInstructionReader(reader, t.inputFilter), + t.outputFilter, + ) +} From cb247563ca070cbc045e3db807efd1ed7864bdce Mon Sep 17 00:00:00 2001 From: Manuel Romei Date: Fri, 18 Feb 2022 09:37:51 +0100 Subject: [PATCH 3/3] test: fix input intercepting filter test Add the test with fragmentation on, i.e. with a long blob. --- input_intercepting_filter_test.go | 30 ++++++++++++++++++++++++++---- stream_test.go | 2 +- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/input_intercepting_filter_test.go b/input_intercepting_filter_test.go index 75a0dd2..9ab8094 100644 --- a/input_intercepting_filter_test.go +++ b/input_intercepting_filter_test.go @@ -2,6 +2,8 @@ package guac import ( "bytes" + "encoding/base64" + "fmt" "testing" "time" ) @@ -18,23 +20,43 @@ func TestInputInterceptingFilter(t *testing.T) { ), )) - toInject := []byte("Hello") + firstBlob := bytes.Repeat([]byte("A"), 4096) + secondBlob := bytes.Repeat([]byte("B"), 100) + + toInject := append(firstBlob, secondBlob...) // Hijack stream 1 and inject some data that will need to end up on the wire finished := f.InterceptStream("1", bytes.NewReader([]byte(toInject))) - // base64("Hello") = "SGVsbG8=" - if got, want := string(conn.ToWrite), "4.blob,1.1,8.SGVsbG8=;"; got != want { + encoded := base64.StdEncoding.EncodeToString(firstBlob) + + if got, want := string(conn.ToWrite), fmt.Sprintf("4.blob,1.1,%d.%s;", len(encoded), encoded); got != want { t.Fatalf("On the wire: %v, want=%v", got, want) } // Simulate an ACK from guacd f.Filter(NewInstruction("ack", "1", "", "0")) + + encoded = base64.StdEncoding.EncodeToString(secondBlob) + + if got, want := string(conn.ToWrite), fmt.Sprintf("4.blob,1.1,%d.%s;", len(encoded), encoded); got != want { + t.Fatalf("On the wire: %v, want=%v", got, want) + } + + // Simulate another ACK from guacd, the packet should have been + // fragmented in two: one which contains the first 4096 bytes + // base64 encoded, and the second which contains the remaining + // 100 bytes base64 encoded. + f.Filter(NewInstruction("ack", "1", "", "0")) + + // There shouldn't be any pending read, so finished should have + // completed by now, if not that's an error and this test should + // timeout. if err := <-finished; err != nil { t.Fatal(err) } - if got, want := string(conn.ToWrite), "4.blob,1.1,8.SGVsbG8=;3.end,1.1;"; got != want { + if got, want := string(conn.ToWrite), "3.end,1.1;"; got != want { t.Fatalf("On the wire: %v, want=%v", got, want) } }) diff --git a/stream_test.go b/stream_test.go index 8e73186..ec07033 100644 --- a/stream_test.go +++ b/stream_test.go @@ -121,7 +121,7 @@ func (f *fakeConn) Write(b []byte) (n int, err error) { f.lock.Lock() defer f.lock.Unlock() - f.ToWrite = append(f.ToWrite, b...) + f.ToWrite = b return len(b), nil }