-
Notifications
You must be signed in to change notification settings - Fork 94
/
Copy pathconnection.go
119 lines (109 loc) · 3.24 KB
/
connection.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
package main
import (
"sync"
"github.com/jaksi/sshutils"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"golang.org/x/crypto/ssh"
)
type (
	// connContext holds the per-connection state shared by the handlers of a
	// single SSH connection.
	connContext struct {
		// Embedded so the connection's metadata methods are available directly
		// on the context; handleConnection populates it with the connection
		// itself.
		ssh.ConnMetadata
		// cfg is the server configuration the connection is served with.
		cfg *config
		// noMoreSessions is mutable per-connection state; handleConnection
		// passes a pointer to the context into handleGlobalRequest, which
		// presumably sets this on a "no-more-sessions" request (TODO confirm
		// against handleGlobalRequest).
		noMoreSessions bool
	}

	// channelContext is the context handed to a channel handler: a copy of the
	// connection context taken when the channel is opened, plus the channel's
	// per-connection ID.
	channelContext struct {
		connContext
		channelID int
	}
)
// channelHandlers maps an SSH channel type (as reported by
// ssh.NewChannel.ChannelType) to the function that services channels of that
// type. handleConnection rejects any channel whose type has no entry here.
var channelHandlers = map[string]func(ssh.NewChannel, channelContext) error{
	"direct-tcpip": handleDirectTCPIPChannel,
	"session":      handleSessionChannel,
}
// sshConnectionsMetric counts every SSH connection handled since startup.
// promauto registers it with the default Prometheus registry.
var sshConnectionsMetric = promauto.NewCounter(prometheus.CounterOpts{
	Name: "sshesame_ssh_connections_total",
	Help: "Total number of SSH connections",
})

// activeSSHConnectionsMetric tracks how many SSH connections are currently
// being handled. It is incremented when handling starts and decremented when
// it finishes.
var activeSSHConnectionsMetric = promauto.NewGauge(prometheus.GaugeOpts{
	Name: "sshesame_active_ssh_connections",
	Help: "Number of active SSH connections",
})

// unknownChannelsMetric counts channel-open requests whose type has no entry
// in channelHandlers.
var unknownChannelsMetric = promauto.NewCounter(prometheus.CounterOpts{
	Name: "sshesame_unknown_channels_total",
	Help: "Total number of unknown channels",
})
// handleConnection services one established SSH connection until both its
// global-request stream and its new-channel stream have been closed.
//
// It does the following:
//   - Updates the connection metrics and logs the connection (client version).
//   - Advertises the server's public host keys to the client with a
//     "hostkeys-00@openssh.com" global request.
//   - Dispatches incoming global requests to handleGlobalRequest. That call
//     receives a pointer to the connection context and may update it.
//   - Dispatches new channels to the handler registered in channelHandlers for
//     their type. Each handler runs in its own goroutine with a copy of the
//     connection context taken at open time. Unsupported channel types are
//     rejected.
//
// The golang.org/x/crypto/ssh documentation requires the Requests and
// NewChannels channels to be serviced, or the connection hangs. On a handling
// failure the connection is therefore closed and both channels keep being
// drained until the library closes them. They are never abandoned, because an
// unread channel can block the library's read loop and leave this function
// (and its goroutine) stuck forever.
//
// On return the connection is closed and every channel goroutine is waited for
// before the connection-close event is logged.
//
// NOTE: channelID is only advanced for accepted channels, so a rejected
// channel is debug-logged with the ID the next accepted channel will receive.
func handleConnection(conn *sshutils.Conn, cfg *config) {
	sshConnectionsMetric.Inc()
	activeSSHConnectionsMetric.Inc()
	defer activeSSHConnectionsMetric.Dec()

	var channels sync.WaitGroup
	context := connContext{ConnMetadata: conn, cfg: cfg}
	defer func() {
		conn.Close()
		// Let channel handlers finish so the close event is logged last.
		channels.Wait()
		context.logEvent(connectionCloseLog{})
	}()

	context.logEvent(connectionLog{
		ClientVersion: string(conn.ClientVersion()),
	})

	hostKeysPayload := make([][]byte, len(cfg.parsedHostKeys))
	for i, key := range cfg.parsedHostKeys {
		hostKeysPayload[i] = key.PublicKey().Marshal()
	}
	// OpenSSH host-key update/rotation extension; the client does not reply.
	if _, _, err := conn.SendRequest("hostkeys-00@openssh.com", false, marshalBytes(hostKeysPayload)); err != nil {
		warningLogger.Printf("Failed to send hostkeys-00@openssh.com request: %v", err)
		return
	}

	channelID := 0
	// A nil channel blocks forever in select, so each stream is set to nil only
	// once the library has closed it; the loop ends when both are closed.
	for conn.Requests != nil || conn.NewChannels != nil {
		select {
		case request, ok := <-conn.Requests:
			if !ok {
				conn.Requests = nil
				continue
			}
			context.logEvent(debugGlobalRequestLog{
				RequestType: request.Type,
				WantReply:   request.WantReply,
				Payload:     string(request.Payload),
			})
			if err := handleGlobalRequest(request, &context); err != nil {
				warningLogger.Printf("Failed to handle global request: %v", err)
				// Tear the connection down, but keep draining Requests until
				// the library closes it; abandoning it could block the
				// library's read loop and hang this function.
				conn.Close()
			}
		case newChannel, ok := <-conn.NewChannels:
			if !ok {
				conn.NewChannels = nil
				continue
			}
			context.logEvent(debugChannelLog{
				channelLog:  channelLog{ChannelID: channelID},
				ChannelType: newChannel.ChannelType(),
				ExtraData:   string(newChannel.ExtraData()),
			})
			channelType := newChannel.ChannelType()
			handler := channelHandlers[channelType]
			if handler == nil {
				unknownChannelsMetric.Inc()
				warningLogger.Printf("Unsupported channel type %v", channelType)
				if err := newChannel.Reject(ssh.ConnectionFailed, "open failed"); err != nil {
					warningLogger.Printf("Failed to reject channel: %v", err)
					// Same reasoning as above: close, but keep draining.
					conn.Close()
				}
				continue
			}
			channels.Add(1)
			// newChannel and handler are declared inside the loop body, so each
			// goroutine captures its own copies.
			go func(context channelContext) {
				defer channels.Done()
				if err := handler(newChannel, context); err != nil {
					warningLogger.Printf("Failed to handle new channel: %v", err)
					conn.Close()
				}
			}(channelContext{context, channelID})
			channelID++
		}
	}
}