Skip to content

Commit bafb299

Browse files
authored
Refactor tracessh.Client to cover untraced paths - session requests (#59291)
* Move ssh session helper into a separate file. * Replace ssh session ChannelRequestCallback with a more x/crypto/ssh-like approach to request channel handling. * Update comments. * Move new request handling logic into tracessh client. * Fix goroutine leaks with refactors. * Add test. * Cleanup. * Rename methods. * Fix lint. * Address comments; Fix error and cleanup handling issues in tracing ssh tests. * Increase test timeouts. * Fix test data race.
1 parent 6d4961a commit bafb299

File tree

9 files changed

+299
-172
lines changed

9 files changed

+299
-172
lines changed

api/observability/tracing/ssh/client.go

Lines changed: 95 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ type Client struct {
3636
*ssh.Client
3737
opts []tracing.Option
3838
capability tracingCapability
39+
40+
requestHandlersMu sync.Mutex
41+
requestHandlers map[string]RequestHandlerFn
3942
}
4043

4144
type tracingCapability int
@@ -56,9 +59,10 @@ const (
5659
// of whether they should provide tracing context.
5760
func NewClient(c ssh.Conn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request, opts ...tracing.Option) *Client {
5861
clt := &Client{
59-
Client: ssh.NewClient(c, chans, reqs),
60-
opts: opts,
61-
capability: tracingUnsupported,
62+
Client: ssh.NewClient(c, chans, reqs),
63+
opts: opts,
64+
capability: tracingUnsupported,
65+
requestHandlers: map[string]RequestHandlerFn{},
6266
}
6367

6468
if bytes.HasPrefix(clt.ServerVersion(), []byte("SSH-2.0-Teleport")) {
@@ -89,7 +93,7 @@ func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, err
8993
)
9094
defer span.End()
9195

92-
// create the wrapper while the lock is held
96+
// create a new wrapper to propagate tracing span context.
9397
wrapper := &clientWrapper{
9498
capability: c.capability,
9599
Conn: c.Client.Conn,
@@ -165,18 +169,6 @@ func (c *Client) OpenChannel(
165169
// NewSession creates a new SSH session that is passed tracing context
166170
// so that spans may be correlated properly over the ssh connection.
167171
func (c *Client) NewSession(ctx context.Context) (*Session, error) {
168-
return c.newSession(ctx, nil)
169-
}
170-
171-
// NewSessionWithRequestCallback creates a new SSH session that is passed
172-
// tracing context so that spans may be correlated properly over the ssh
173-
// connection. The handling of channel requests from the underlying SSH
174-
// session can be controlled with chanReqCallback.
175-
func (c *Client) NewSessionWithRequestCallback(ctx context.Context, chanReqCallback ChannelRequestCallback) (*Session, error) {
176-
return c.newSession(ctx, chanReqCallback)
177-
}
178-
179-
func (c *Client) newSession(ctx context.Context, chanReqCallback ChannelRequestCallback) (*Session, error) {
180172
tracer := tracing.NewConfig(c.opts).TracerProvider.Tracer(instrumentationName)
181173

182174
ctx, span := tracer.Start(
@@ -194,7 +186,7 @@ func (c *Client) newSession(ctx context.Context, chanReqCallback ChannelRequestC
194186
)
195187
defer span.End()
196188

197-
// create the wrapper while the lock is still held
189+
// create a new wrapper to propagate tracing span context.
198190
wrapper := &clientWrapper{
199191
capability: c.capability,
200192
Conn: c.Client.Conn,
@@ -203,9 +195,92 @@ func (c *Client) newSession(ctx context.Context, chanReqCallback ChannelRequestC
203195
contexts: make(map[string][]context.Context),
204196
}
205197

206-
// get a session from the wrapper
207-
session, err := wrapper.NewSession(chanReqCallback)
208-
return session, trace.Wrap(err)
198+
// open a session manually so we can take ownership of the
199+
// requests chan
200+
ch, reqs, err := wrapper.OpenChannel("session", nil)
201+
if err != nil {
202+
return nil, trace.Wrap(err)
203+
}
204+
205+
unhandledReqs := c.serveSessionRequests(ctx, reqs)
206+
session, err := newCryptoSSHSession(ch, unhandledReqs)
207+
if err != nil {
208+
_ = ch.Close()
209+
return nil, trace.Wrap(err)
210+
}
211+
212+
// wrap the session so all session requests on the channel
213+
// can be traced
214+
return &Session{
215+
Session: session,
216+
wrapper: wrapper,
217+
}, nil
218+
}
219+
220+
// RequestHandlerFn is an ssh request handler function.
221+
type RequestHandlerFn func(ctx context.Context, ch *ssh.Request)
222+
223+
// HandleSessionRequest registers a handler for any incoming [ssh.Request] matching the
224+
// provided type within a session. If the type is already being handled, an error is returned.
225+
// All registered handlers are consumed by the next call to [Client.NewSession].
226+
func (c *Client) HandleSessionRequest(ctx context.Context, requestType string, handlerFn RequestHandlerFn) error {
227+
c.requestHandlersMu.Lock()
228+
defer c.requestHandlersMu.Unlock()
229+
230+
if _, ok := c.requestHandlers[requestType]; ok {
231+
return trace.AlreadyExists("ssh request type %q is already being handled for this session", requestType)
232+
}
233+
234+
c.requestHandlers[requestType] = handlerFn
235+
return nil
236+
}
237+
238+
// serveSessionRequests from the remote side with registered handlers.
239+
//
240+
// This method consumes all registered handlers so that the next call to
241+
// [Client.NewSession] will not reuse the same handlers.
242+
func (c *Client) serveSessionRequests(ctx context.Context, in <-chan *ssh.Request) <-chan *ssh.Request {
243+
c.requestHandlersMu.Lock()
244+
requestHandlers := c.requestHandlers
245+
c.requestHandlers = make(map[string]RequestHandlerFn)
246+
c.requestHandlersMu.Unlock()
247+
248+
// Capture requests not handled by registered request handlers and
249+
// pass them to the crypto [ssh.Session].
250+
unhandledReqs := make(chan *ssh.Request, cap(in))
251+
252+
tracer := tracing.NewConfig(c.opts).TracerProvider.Tracer(instrumentationName)
253+
go func() {
254+
defer close(unhandledReqs)
255+
for req := range in {
256+
ctx, span := tracer.Start(
257+
ctx,
258+
fmt.Sprintf("ssh.HandleRequests/%s", req.Type),
259+
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
260+
oteltrace.WithAttributes(
261+
append(
262+
peerAttr(c.Conn.RemoteAddr()),
263+
semconv.RPCServiceKey.String("ssh.Client"),
264+
semconv.RPCMethodKey.String("HandleRequests"),
265+
semconv.RPCSystemKey.String("ssh"),
266+
)...,
267+
),
268+
)
269+
270+
handler, ok := requestHandlers[req.Type]
271+
if ok {
272+
handler(ctx, req)
273+
} else {
274+
// Pass on requests without a registered handler. These will be
275+
// handled by the default x/crypto/ssh request handler.
276+
unhandledReqs <- req
277+
}
278+
279+
span.End()
280+
}
281+
}()
282+
283+
return unhandledReqs
209284
}
210285

211286
// clientWrapper wraps the ssh.Conn for individual ssh.Client
@@ -229,64 +304,6 @@ type clientWrapper struct {
229304
contexts map[string][]context.Context
230305
}
231306

232-
// ChannelRequestCallback allows the handling of channel requests
233-
// to be customized. nil can be returned if you don't want
234-
// golang/x/crypto/ssh to handle the request.
235-
type ChannelRequestCallback func(req *ssh.Request) *ssh.Request
236-
237-
// NewSession opens a new Session for this client.
238-
func (c *clientWrapper) NewSession(callback ChannelRequestCallback) (*Session, error) {
239-
// create a client that will defer to us when
240-
// opening the "session" channel so that we
241-
// can add an Envelope to the request
242-
client := &ssh.Client{
243-
Conn: c,
244-
}
245-
246-
var session *ssh.Session
247-
var err error
248-
if callback != nil {
249-
// open a session manually so we can take ownership of the
250-
// requests chan
251-
ch, originalReqs, openChannelErr := client.OpenChannel("session", nil)
252-
if openChannelErr != nil {
253-
return nil, trace.Wrap(openChannelErr)
254-
}
255-
256-
// pass the channel requests to the provided callback and
257-
// forward them to another chan so golang.org/x/crypto/ssh
258-
// can handle Session exiting correctly
259-
reqs := make(chan *ssh.Request, cap(originalReqs))
260-
go func() {
261-
defer close(reqs)
262-
263-
for req := range originalReqs {
264-
if req := callback(req); req != nil {
265-
reqs <- req
266-
}
267-
}
268-
}()
269-
270-
session, err = newCryptoSSHSession(ch, reqs)
271-
if err != nil {
272-
_ = ch.Close()
273-
return nil, trace.Wrap(err)
274-
}
275-
} else {
276-
session, err = client.NewSession()
277-
if err != nil {
278-
return nil, trace.Wrap(err)
279-
}
280-
}
281-
282-
// wrap the session so all session requests on the channel
283-
// can be traced
284-
return &Session{
285-
Session: session,
286-
wrapper: c,
287-
}, nil
288-
}
289-
290307
// wrappedSSHConn allows an SSH session to be created while also allowing
291308
// callers to take ownership of the SSH channel requests chan.
292309
type wrappedSSHConn struct {

0 commit comments

Comments
 (0)