Skip to content

Commit ec371e4

Browse files
committed
fix: Connection counting
1 parent af48bba commit ec371e4

15 files changed

+1744
-82
lines changed

.gitignore

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,13 @@ coverage.html
6666
/.gocache/
6767
/.claude
6868

69-
/msgtausch-stats.db
70-
/msgtausch-stats.db-shm
71-
/msgtausch-stats.db-wal
69+
msgtausch-stats*.db
70+
msgtausch-stats*.db-shm
71+
msgtausch-stats*.db-wal
72+
msgtausch_stats*.db
73+
msgtausch_stats*.db-shm
74+
msgtausch_stats*.db-wal
75+
76+
test_proxy_stats*.db
77+
test_proxy_stats*.db-shm
78+
test_proxy_stats*.db-wal

flake.nix

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
go = pkgs.go_1_24;
4040
goVersion = "1.24";
4141
# Let the builder vendor dependencies internally, but ignore any in-tree vendor/
42-
vendorHash = "sha256-rFUVAUivUxhDHo/COi5mfX3Mfoqhfma3MuRn48Sxuqg=";
42+
vendorHash = "sha256-a2mjJVwYYkSATYC/EH5qqfCNqzXVefqdYJZutuuiZHA=";
4343
stripVendor = true;
4444
subPackages = [ "." ];
4545
env.CGO_ENABLED = 1;

msgtausch-srv/dashboard/dashboard.go

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,8 @@ package dashboard
33
import (
44
"encoding/json"
55
"net/http"
6-
"sync"
76
"time"
87

9-
"github.com/codefionn/msgtausch/msgtausch-srv/config"
108
"github.com/codefionn/msgtausch/msgtausch-srv/logger"
119
"github.com/codefionn/msgtausch/msgtausch-srv/stats"
1210
)
@@ -20,20 +18,6 @@ func writeJSON(w http.ResponseWriter, data interface{}) {
2018
}
2119
}
2220

23-
// Dashboard provides a web interface for viewing proxy statistics
24-
type Dashboard struct {
25-
config *config.Config
26-
collector stats.Collector
27-
28-
mutex sync.RWMutex
29-
cache *cache
30-
}
31-
32-
type cache struct {
33-
lastUpdate time.Time
34-
data *Data
35-
}
36-
3721
// Data represents the statistics data for the dashboard
3822
type Data struct {
3923
Overview *stats.OverviewStats `json:"overview"`

msgtausch-srv/logger/logger.go

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"log"
66
"os"
77
"strings"
8+
"sync/atomic"
89
)
910

1011
// LogLevel represents the severity of a log message
@@ -25,19 +26,28 @@ const (
2526
)
2627

2728
var (
28-
// currentLevel is the current logging level
29-
currentLevel LogLevel = INFO
29+
// currentLevel holds the current logging level atomically
30+
currentLevel atomic.Int32
3031
// stdLogger is the standard logger instance
3132
stdLogger = log.New(os.Stdout, "", log.LstdFlags)
3233
)
3334

35+
func init() {
36+
currentLevel.Store(int32(INFO))
37+
}
38+
3439
// SetLevel sets the current logging level
3540
func SetLevel(level LogLevel) {
36-
currentLevel = level
41+
currentLevel.Store(int32(level))
3742
}
3843

3944
func IsLevelEnabled(level LogLevel) bool {
40-
return level >= currentLevel
45+
return level >= LogLevel(currentLevel.Load())
46+
}
47+
48+
// GetLevel returns the current logging level.
49+
func GetLevel() LogLevel {
50+
return LogLevel(currentLevel.Load())
4151
}
4252

4353
// GetLevelFromString converts a string level to LogLevel
@@ -82,7 +92,7 @@ func levelToString(level LogLevel) string {
8292

8393
// logMessage logs a message at the specified level with optional context
8494
func logMessage(level LogLevel, format string, v ...any) {
85-
if level < currentLevel {
95+
if level < LogLevel(currentLevel.Load()) {
8696
return
8797
}
8898

msgtausch-srv/logger/logger_test.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,16 @@ func TestSetLevel(t *testing.T) {
4242
}
4343

4444
// Save the original level to restore it after the test
45-
originalLevel := currentLevel
45+
originalLevel := GetLevel()
4646
defer func() {
47-
currentLevel = originalLevel
47+
SetLevel(originalLevel)
4848
}()
4949

5050
for _, tt := range tests {
5151
t.Run(tt.name, func(t *testing.T) {
5252
SetLevel(tt.level)
53-
if currentLevel != tt.expectedLevel {
54-
t.Errorf("SetLevel() = %v, want %v", currentLevel, tt.expectedLevel)
53+
if GetLevel() != tt.expectedLevel {
54+
t.Errorf("SetLevel() = %v, want %v", GetLevel(), tt.expectedLevel)
5555
}
5656
})
5757
}
@@ -139,15 +139,15 @@ func TestLogLevelFiltering(t *testing.T) {
139139
}
140140

141141
// Save the original level to restore it after the test
142-
originalLevel := currentLevel
142+
originalLevel := GetLevel()
143143
defer func() {
144-
currentLevel = originalLevel
144+
SetLevel(originalLevel)
145145
}()
146146

147147
for _, tt := range tests {
148148
t.Run(tt.name, func(t *testing.T) {
149149
// Set the current log level
150-
currentLevel = tt.currentLevel
150+
SetLevel(tt.currentLevel)
151151

152152
// Capture the output
153153
output := captureOutput(func() {
@@ -164,7 +164,7 @@ func TestLogLevelFiltering(t *testing.T) {
164164
case FATAL:
165165
// Special case for FATAL to avoid os.Exit
166166
// We're just testing the filtering logic, not the exit behavior
167-
if tt.logLevel >= currentLevel {
167+
if IsLevelEnabled(FATAL) {
168168
stdLogger.Printf("[%s] %s", levelToString(FATAL), "test message")
169169
}
170170
}
@@ -223,9 +223,9 @@ func TestLogFormatting(t *testing.T) {
223223
}
224224

225225
// Save the original level to restore it after the test
226-
originalLevel := currentLevel
226+
originalLevel := GetLevel()
227227
defer func() {
228-
currentLevel = originalLevel
228+
SetLevel(originalLevel)
229229
}()
230230

231231
// Set level to DEBUG to ensure all messages are logged
@@ -254,9 +254,9 @@ func TestLogFormatting(t *testing.T) {
254254
// TestFatalBehavior tests the formatting of Fatal messages without actually calling os.Exit
255255
func TestFatalBehavior(t *testing.T) {
256256
// Save the original level to restore it after the test
257-
originalLevel := currentLevel
257+
originalLevel := GetLevel()
258258
defer func() {
259-
currentLevel = originalLevel
259+
SetLevel(originalLevel)
260260
}()
261261

262262
// Set level to DEBUG to ensure the message is logged

msgtausch-srv/proxy/proxy.go

Lines changed: 101 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package proxy
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
67
"io"
78
"net"
@@ -730,6 +731,33 @@ func isClosedConnError(err error) bool {
730731
return strings.Contains(err.Error(), "use of closed network connection")
731732
}
732733

734+
// copyWithIdleTimeout copies data from src to dst, enforcing an idle timeout.
735+
// If no data is received for the given timeout, the copy returns with a timeout error.
736+
func copyWithIdleTimeout(dst, src net.Conn, timeout time.Duration) error {
737+
buf := make([]byte, 32*1024)
738+
for {
739+
_ = src.SetReadDeadline(time.Now().Add(timeout))
740+
n, rerr := src.Read(buf)
741+
if n > 0 {
742+
if _, werr := dst.Write(buf[:n]); werr != nil {
743+
if !isClosedConnError(werr) {
744+
return werr
745+
}
746+
return nil
747+
}
748+
}
749+
if rerr != nil {
750+
if ne, ok := rerr.(net.Error); ok && ne.Timeout() {
751+
return ne
752+
}
753+
if errors.Is(rerr, io.EOF) || isClosedConnError(rerr) {
754+
return nil
755+
}
756+
return rerr
757+
}
758+
}
759+
}
760+
733761
func (p *Server) handleRequest(w http.ResponseWriter, r *http.Request) {
734762
ctx := r.Context()
735763
targetAddr := r.Host
@@ -978,14 +1006,31 @@ func (p *Server) forwardRequest(w http.ResponseWriter, r *http.Request, client *
9781006
isWebSocketResponse := resp.StatusCode == http.StatusSwitchingProtocols &&
9791007
strings.ToLower(resp.Header.Get("Upgrade")) == "websocket"
9801008

981-
if isWebSocketRequest && isWebSocketResponse {
1009+
if (isWebSocketRequest || isProxiedWebSocket) && isWebSocketResponse {
1010+
// For WebSocket upgrades, record the HTTP request/response before switching to tunnel
1011+
if p.proxy.Collector != nil && connectionID > 0 {
1012+
logger.Debug("Recording WebSocket upgrade request/response for connectionID=%d", connectionID)
1013+
responseHeaderSize := estimateHTTPResponseHeaderSize(resp)
1014+
if err := p.proxy.Collector.RecordHTTPResponseWithHeaders(r.Context(), connectionID, resp.StatusCode, resp.ContentLength, responseHeaderSize); err != nil {
1015+
logger.Error("Failed to record HTTP response: %v", err)
1016+
}
1017+
contentLength := r.ContentLength
1018+
if contentLength < 0 {
1019+
contentLength = 0
1020+
}
1021+
requestHeaderSize := estimateHTTPRequestHeaderSize(r)
1022+
if err := p.proxy.Collector.RecordHTTPRequestWithHeaders(ctx, connectionID, r.Method, targetURL, targetHost, r.UserAgent(), contentLength, requestHeaderSize); err != nil {
1023+
logger.Error("Failed to record HTTP request: %v", err)
1024+
}
1025+
}
9821026
logger.Debug("Handling WebSocket upgrade response from %s", targetHost)
983-
p.handleWebSocketTunnel(w, r, resp, client, connectionID)
1027+
p.handleWebSocketTunnel(w, r, resp, client)
9841028
return
9851029
}
9861030

9871031
if p.proxy.Collector != nil && connectionID > 0 {
9881032
responseHeaderSize := estimateHTTPResponseHeaderSize(resp)
1033+
logger.Debug("Recording HTTP response for connectionID=%d status=%d", connectionID, resp.StatusCode)
9891034
if err := p.proxy.Collector.RecordHTTPResponseWithHeaders(r.Context(), connectionID, resp.StatusCode, resp.ContentLength, responseHeaderSize); err != nil {
9901035
logger.Error("Failed to record HTTP response: %v", err)
9911036
}
@@ -997,7 +1042,8 @@ func (p *Server) forwardRequest(w http.ResponseWriter, r *http.Request, client *
9971042
contentLength = 0
9981043
}
9991044
requestHeaderSize := estimateHTTPRequestHeaderSize(r)
1000-
if err := p.proxy.Collector.RecordHTTPRequestWithHeaders(ctx, connectionID, r.Method, r.URL.RequestURI(), targetHost, r.UserAgent(), contentLength, requestHeaderSize); err != nil {
1045+
logger.Debug("Recording HTTP request for connectionID=%d method=%s url=%s", connectionID, r.Method, targetURL)
1046+
if err := p.proxy.Collector.RecordHTTPRequestWithHeaders(ctx, connectionID, r.Method, targetURL, targetHost, r.UserAgent(), contentLength, requestHeaderSize); err != nil {
10011047
logger.Error("Failed to record HTTP request: %v", err)
10021048
}
10031049
}
@@ -1012,9 +1058,18 @@ func (p *Server) forwardRequest(w http.ResponseWriter, r *http.Request, client *
10121058
if _, err := io.Copy(w, resp.Body); err != nil {
10131059
logger.Error("Failed to copy response body: %v", err)
10141060
}
1061+
1062+
// If statistics collection is enabled, proactively close idle upstream
1063+
// connections so trackedConn.Close triggers EndConnection in tests.
1064+
// This is gated by stats being enabled to avoid impacting keep-alive tests.
1065+
if p.proxy != nil && p.proxy.GetConfig() != nil && p.proxy.GetConfig().Statistics.Enabled {
1066+
if tr, ok := client.Transport.(*http.Transport); ok && tr != nil {
1067+
tr.CloseIdleConnections()
1068+
}
1069+
}
10151070
}
10161071

1017-
func (p *Server) handleWebSocketTunnel(w http.ResponseWriter, r *http.Request, resp *http.Response, client *http.Client, connectionID int64) {
1072+
func (p *Server) handleWebSocketTunnel(w http.ResponseWriter, r *http.Request, resp *http.Response, client *http.Client) {
10181073
hj, ok := w.(http.Hijacker)
10191074
if !ok {
10201075
logger.Error("HTTP server does not support hijacking for WebSocket")
@@ -1076,7 +1131,8 @@ func (p *Server) handleWebSocketTunnel(w http.ResponseWriter, r *http.Request, r
10761131
logger.Error("Failed to connect to WebSocket server or proxy: %v", err)
10771132
return
10781133
}
1079-
targetConn = newTrackedConn(r.Context(), targetConn, p.proxy, connectionID)
1134+
// targetConn is already created via Transport.DialContext which uses
1135+
// createForwardTCPClient and returns a tracked connection. Avoid double-wrapping.
10801136

10811137
logger.Debug("WebSocket tunnel established for %s", targetHost)
10821138

@@ -1106,6 +1162,13 @@ func (p *Server) handleWebSocketTunnel(w http.ResponseWriter, r *http.Request, r
11061162
}
11071163

11081164
var wg sync.WaitGroup
1165+
// Honor global timeout for tunnel lifetime
1166+
tunnelTimeout := time.Duration(p.config.TimeoutSeconds) * time.Second
1167+
if tunnelTimeout <= 0 {
1168+
tunnelTimeout = 30 * time.Second
1169+
}
1170+
ctx, cancel := context.WithTimeout(context.Background(), tunnelTimeout)
1171+
defer cancel()
11091172
wg.Add(2)
11101173

11111174
go func() {
@@ -1128,8 +1191,11 @@ func (p *Server) handleWebSocketTunnel(w http.ResponseWriter, r *http.Request, r
11281191
}
11291192
}
11301193

1131-
_, err := io.Copy(targetConn, clientConn)
1132-
if err != nil && !isClosedConnError(err) {
1194+
idle := time.Duration(p.config.TimeoutSeconds) * time.Second
1195+
if idle <= 0 {
1196+
idle = 30 * time.Second
1197+
}
1198+
if err := copyWithIdleTimeout(targetConn, clientConn, idle); err != nil && !isClosedConnError(err) {
11331199
logger.Error("Failed to copy client to target: %v", err)
11341200
}
11351201
}()
@@ -1141,12 +1207,22 @@ func (p *Server) handleWebSocketTunnel(w http.ResponseWriter, r *http.Request, r
11411207
logger.Error("Error closing client connection: %v", closeErr)
11421208
}
11431209
}()
1144-
_, err := io.Copy(clientConn, targetConn)
1145-
if err != nil && !isClosedConnError(err) {
1210+
idle := time.Duration(p.config.TimeoutSeconds) * time.Second
1211+
if idle <= 0 {
1212+
idle = 30 * time.Second
1213+
}
1214+
if err := copyWithIdleTimeout(clientConn, targetConn, idle); err != nil && !isClosedConnError(err) {
11461215
logger.Error("Failed to copy target to client: %v", err)
11471216
}
11481217
}()
11491218

1219+
// Force-close on timeout
1220+
go func() {
1221+
<-ctx.Done()
1222+
clientConn.Close()
1223+
targetConn.Close()
1224+
}()
1225+
11501226
wg.Wait()
11511227
logger.Debug("WebSocket tunnel closed for %s", targetHost)
11521228
}
@@ -1311,8 +1387,12 @@ func (p *Server) handleConnect(w http.ResponseWriter, r *http.Request, connectio
13111387
var wg sync.WaitGroup
13121388
wg.Add(2)
13131389

1314-
// Create a context to coordinate tunnel shutdown
1315-
ctx, cancel := context.WithCancel(context.Background())
1390+
// Create a context to coordinate tunnel shutdown, honor global timeout
1391+
tunnelTimeout := time.Duration(p.config.TimeoutSeconds) * time.Second
1392+
if tunnelTimeout <= 0 {
1393+
tunnelTimeout = 30 * time.Second
1394+
}
1395+
ctx, cancel := context.WithTimeout(context.Background(), tunnelTimeout)
13161396
defer cancel()
13171397

13181398
go func() {
@@ -1326,7 +1406,11 @@ func (p *Server) handleConnect(w http.ResponseWriter, r *http.Request, connectio
13261406
return
13271407
}
13281408
}
1329-
if _, err := io.Copy(targetConn, clientConn); err != nil {
1409+
idle := time.Duration(p.config.TimeoutSeconds) * time.Second
1410+
if idle <= 0 {
1411+
idle = 30 * time.Second
1412+
}
1413+
if err := copyWithIdleTimeout(targetConn, clientConn, idle); err != nil {
13301414
if !isClosedConnError(err) {
13311415
logger.Warn("TCP tunnel copy error (client to target): %v", err)
13321416
}
@@ -1342,7 +1426,11 @@ func (p *Server) handleConnect(w http.ResponseWriter, r *http.Request, connectio
13421426
go func() {
13431427
defer wg.Done()
13441428
defer cancel() // Cancel context when this goroutine exits
1345-
if _, err := io.Copy(clientConn, targetConn); err != nil {
1429+
idle := time.Duration(p.config.TimeoutSeconds) * time.Second
1430+
if idle <= 0 {
1431+
idle = 30 * time.Second
1432+
}
1433+
if err := copyWithIdleTimeout(clientConn, targetConn, idle); err != nil {
13461434
if !isClosedConnError(err) {
13471435
logger.Warn("TCP tunnel copy error (target to client): %v", err)
13481436
}

0 commit comments

Comments
 (0)