@@ -36,6 +36,9 @@ type Client struct {
3636 * ssh.Client
3737 opts []tracing.Option
3838 capability tracingCapability
39+
40+ requestHandlersMu sync.Mutex
41+ requestHandlers map [string ]RequestHandlerFn
3942}
4043
4144type tracingCapability int
@@ -56,9 +59,10 @@ const (
5659// of whether they should provide tracing context.
5760func NewClient (c ssh.Conn , chans <- chan ssh.NewChannel , reqs <- chan * ssh.Request , opts ... tracing.Option ) * Client {
5861 clt := & Client {
59- Client : ssh .NewClient (c , chans , reqs ),
60- opts : opts ,
61- capability : tracingUnsupported ,
62+ Client : ssh .NewClient (c , chans , reqs ),
63+ opts : opts ,
64+ capability : tracingUnsupported ,
65+ requestHandlers : map [string ]RequestHandlerFn {},
6266 }
6367
6468 if bytes .HasPrefix (clt .ServerVersion (), []byte ("SSH-2.0-Teleport" )) {
@@ -89,7 +93,7 @@ func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, err
8993 )
9094 defer span .End ()
9195
92- // create the wrapper while the lock is held
96+ // create a new wrapper to propagate tracing span context.
9397 wrapper := & clientWrapper {
9498 capability : c .capability ,
9599 Conn : c .Client .Conn ,
@@ -165,18 +169,6 @@ func (c *Client) OpenChannel(
165169// NewSession creates a new SSH session that is passed tracing context
166170// so that spans may be correlated properly over the ssh connection.
167171func (c * Client ) NewSession (ctx context.Context ) (* Session , error ) {
168- return c .newSession (ctx , nil )
169- }
170-
171- // NewSessionWithRequestCallback creates a new SSH session that is passed
172- // tracing context so that spans may be correlated properly over the ssh
173- // connection. The handling of channel requests from the underlying SSH
174- // session can be controlled with chanReqCallback.
175- func (c * Client ) NewSessionWithRequestCallback (ctx context.Context , chanReqCallback ChannelRequestCallback ) (* Session , error ) {
176- return c .newSession (ctx , chanReqCallback )
177- }
178-
179- func (c * Client ) newSession (ctx context.Context , chanReqCallback ChannelRequestCallback ) (* Session , error ) {
180172 tracer := tracing .NewConfig (c .opts ).TracerProvider .Tracer (instrumentationName )
181173
182174 ctx , span := tracer .Start (
@@ -194,7 +186,7 @@ func (c *Client) newSession(ctx context.Context, chanReqCallback ChannelRequestC
194186 )
195187 defer span .End ()
196188
197- // create the wrapper while the lock is still held
189+ // create a new wrapper to propagate tracing span context.
198190 wrapper := & clientWrapper {
199191 capability : c .capability ,
200192 Conn : c .Client .Conn ,
@@ -203,9 +195,92 @@ func (c *Client) newSession(ctx context.Context, chanReqCallback ChannelRequestC
203195 contexts : make (map [string ][]context.Context ),
204196 }
205197
206- // get a session from the wrapper
207- session , err := wrapper .NewSession (chanReqCallback )
208- return session , trace .Wrap (err )
198+ // open a session manually so we can take ownership of the
199+ // requests chan
200+ ch , reqs , err := wrapper .OpenChannel ("session" , nil )
201+ if err != nil {
202+ return nil , trace .Wrap (err )
203+ }
204+
205+ unhandledReqs := c .serveSessionRequests (ctx , reqs )
206+ session , err := newCryptoSSHSession (ch , unhandledReqs )
207+ if err != nil {
208+ _ = ch .Close ()
209+ return nil , trace .Wrap (err )
210+ }
211+
212+ // wrap the session so all session requests on the channel
213+ // can be traced
214+ return & Session {
215+ Session : session ,
216+ wrapper : wrapper ,
217+ }, nil
218+ }
219+
220+ // RequestHandlerFn is an ssh request handler function.
221+ type RequestHandlerFn func (ctx context.Context , ch * ssh.Request )
222+
223+ // HandleSessionRequest registers a handler for any incoming [ssh.Request] matching the
224+ // provided type within a session. If the type is already being handled, an error is returned.
225+ // All registered handlers are consumed by the next call to [Client.NewSession].
226+ func (c * Client ) HandleSessionRequest (ctx context.Context , requestType string , handlerFn RequestHandlerFn ) error {
227+ c .requestHandlersMu .Lock ()
228+ defer c .requestHandlersMu .Unlock ()
229+
230+ if _ , ok := c .requestHandlers [requestType ]; ok {
231+ return trace .AlreadyExists ("ssh request type %q is already being handled for this session" , requestType )
232+ }
233+
234+ c .requestHandlers [requestType ] = handlerFn
235+ return nil
236+ }
237+
238+ // serveSessionRequests from the remote side with registered handlers.
239+ //
240+ // This method consumes all registered handlers so that the next call to
241+ // [Client.NewSession] will not reuse the same handlers.
242+ func (c * Client ) serveSessionRequests (ctx context.Context , in <- chan * ssh.Request ) <- chan * ssh.Request {
243+ c .requestHandlersMu .Lock ()
244+ requestHandlers := c .requestHandlers
245+ c .requestHandlers = make (map [string ]RequestHandlerFn )
246+ c .requestHandlersMu .Unlock ()
247+
248+ // Capture requests not handled by registered request handlers and
249+ // pass them to the crypto [ssh.Session].
250+ unhandledReqs := make (chan * ssh.Request , cap (in ))
251+
252+ tracer := tracing .NewConfig (c .opts ).TracerProvider .Tracer (instrumentationName )
253+ go func () {
254+ defer close (unhandledReqs )
255+ for req := range in {
256+ ctx , span := tracer .Start (
257+ ctx ,
258+ fmt .Sprintf ("ssh.HandleRequests/%s" , req .Type ),
259+ oteltrace .WithSpanKind (oteltrace .SpanKindClient ),
260+ oteltrace .WithAttributes (
261+ append (
262+ peerAttr (c .Conn .RemoteAddr ()),
263+ semconv .RPCServiceKey .String ("ssh.Client" ),
264+ semconv .RPCMethodKey .String ("HandleRequests" ),
265+ semconv .RPCSystemKey .String ("ssh" ),
266+ )... ,
267+ ),
268+ )
269+
270+ handler , ok := requestHandlers [req .Type ]
271+ if ok {
272+ handler (ctx , req )
273+ } else {
274+ // Pass on requests without a registered handler. These will be
275+ // handled by the default x/crypto/ssh request handler.
276+ unhandledReqs <- req
277+ }
278+
279+ span .End ()
280+ }
281+ }()
282+
283+ return unhandledReqs
209284}
210285
211286// clientWrapper wraps the ssh.Conn for individual ssh.Client
@@ -229,64 +304,6 @@ type clientWrapper struct {
229304 contexts map [string ][]context.Context
230305}
231306
232- // ChannelRequestCallback allows the handling of channel requests
233- // to be customized. nil can be returned if you don't want
234- // golang/x/crypto/ssh to handle the request.
235- type ChannelRequestCallback func (req * ssh.Request ) * ssh.Request
236-
237- // NewSession opens a new Session for this client.
238- func (c * clientWrapper ) NewSession (callback ChannelRequestCallback ) (* Session , error ) {
239- // create a client that will defer to us when
240- // opening the "session" channel so that we
241- // can add an Envelope to the request
242- client := & ssh.Client {
243- Conn : c ,
244- }
245-
246- var session * ssh.Session
247- var err error
248- if callback != nil {
249- // open a session manually so we can take ownership of the
250- // requests chan
251- ch , originalReqs , openChannelErr := client .OpenChannel ("session" , nil )
252- if openChannelErr != nil {
253- return nil , trace .Wrap (openChannelErr )
254- }
255-
256- // pass the channel requests to the provided callback and
257- // forward them to another chan so golang.org/x/crypto/ssh
258- // can handle Session exiting correctly
259- reqs := make (chan * ssh.Request , cap (originalReqs ))
260- go func () {
261- defer close (reqs )
262-
263- for req := range originalReqs {
264- if req := callback (req ); req != nil {
265- reqs <- req
266- }
267- }
268- }()
269-
270- session , err = newCryptoSSHSession (ch , reqs )
271- if err != nil {
272- _ = ch .Close ()
273- return nil , trace .Wrap (err )
274- }
275- } else {
276- session , err = client .NewSession ()
277- if err != nil {
278- return nil , trace .Wrap (err )
279- }
280- }
281-
282- // wrap the session so all session requests on the channel
283- // can be traced
284- return & Session {
285- Session : session ,
286- wrapper : c ,
287- }, nil
288- }
289-
290307// wrappedSSHConn allows an SSH session to be created while also allowing
291308// callers to take ownership of the SSH channel requests chan.
292309type wrappedSSHConn struct {
0 commit comments