diff --git a/libproxy/proxy.go b/libproxy/proxy.go index 3ea3b7a..634b660 100644 --- a/libproxy/proxy.go +++ b/libproxy/proxy.go @@ -10,12 +10,33 @@ import ( "mime/multipart" "net/http" "net/url" + "os" + "runtime" "strconv" "strings" + "sync/atomic" + "time" "github.com/google/uuid" ) +// Initialize loggers +var ( + InfoLogger *log.Logger + ErrorLogger *log.Logger + DebugLogger *log.Logger +) + +// Health metrics +var ( + startTime time.Time + totalRequests uint64 + totalErrors uint64 + lastRequestTime time.Time + isServerRunning bool + healthCheckEnabled bool +) + type statusChangeFunction func(status string, isListening bool) var ( @@ -49,6 +70,48 @@ type Response struct { Headers map[string]string `json:"headers"` } +// HealthResponse contains the health check information +type HealthResponse struct { + Status string `json:"status"` + Uptime string `json:"uptime"` + StartTime time.Time `json:"startTime"` + TotalRequests uint64 `json:"totalRequests"` + TotalErrors uint64 `json:"totalErrors"` + LastRequestTime time.Time `json:"lastRequestTime,omitempty"` + Version string `json:"version"` + GoVersion string `json:"goVersion"` + NumGoroutine int `json:"numGoroutine"` + MemoryAllocated uint64 `json:"memoryAllocated"` + MemoryTotal uint64 `json:"memoryTotal"` + MemorySystemUsed uint64 `json:"memorySystemUsed"` +} + +func setupLoggers() { + // Create logs directory if it doesn't exist + logDir := "logs" + if _, err := os.Stat(logDir); os.IsNotExist(err) { + err := os.Mkdir(logDir, 0755) + if err != nil { + log.Println("Failed to create logs directory:", err) + } + } + + // Create log file with date in filename + currentTime := time.Now().Format("2006-01-02") + logFile, err := os.OpenFile(fmt.Sprintf("%s/proxy-%s.log", logDir, currentTime), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + log.Println("Failed to open log file:", err) + } + + // Create multi-writer to write logs to both file and stdout + multiWriter := io.MultiWriter(os.Stdout, logFile) + + // Initialize loggers with different prefixes + InfoLogger = log.New(multiWriter, "INFO: ", log.Ldate|log.Ltime|log.Lshortfile) + ErrorLogger = log.New(multiWriter, "ERROR: ", log.Ldate|log.Ltime|log.Lshortfile) + DebugLogger = log.New(multiWriter, "DEBUG: ", log.Ldate|log.Ltime|log.Lshortfile) +} + func isAllowedDest(dest string) bool { for _, b := range bannedDests { if b == dest { @@ -83,6 +146,15 @@ func Initialize( withSSL bool, finished chan bool, ) { + // Set up loggers first + setupLoggers() + + // Initialize health metrics + startTime = time.Now() + totalRequests = 0 + totalErrors = 0 + healthCheckEnabled = true + if initialBannedOutputs != "" { bannedOutputs = strings.Split(initialBannedOutputs, ",") } @@ -94,15 +166,23 @@ func Initialize( allowedOrigins = strings.Split(initialAllowedOrigins, ",") accessToken = initialAccessToken sessionFingerprint = uuid.New().String() - log.Println("Starting proxy server...") + InfoLogger.Println("Starting proxy server...") + // Register handlers http.HandleFunc("/", proxyHandler) + http.HandleFunc("/health", healthCheckHandler) + http.HandleFunc("/metrics", metricsHandler) if !withSSL { go func() { + InfoLogger.Printf("Attempting to listen on http://%s/", proxyURL) + isServerRunning = true httpServerError := http.ListenAndServe(proxyURL, nil) if httpServerError != nil { + isServerRunning = false + errorMsg := fmt.Sprintf("Server failed to start: %v", httpServerError) + ErrorLogger.Println(errorMsg) onStatusChange("An error occurred: "+httpServerError.Error(), false) } @@ -110,25 +190,34 @@ func Initialize( }() onStatusChange("Listening on http://"+proxyURL+"/", true) + InfoLogger.Printf("Proxy server listening on http://%s/", proxyURL) + InfoLogger.Printf("Health check available at http://%s/health", proxyURL) } else { onStatusChange("Checking SSL certificate...", false) + InfoLogger.Println("Checking SSL certificate...") err := EnsurePrivateKeyInstalled() if err != nil { - log.Println(err.Error()) - onStatusChange("An error occurred.", false) + ErrorLogger.Printf("SSL certificate error: %v", err) + onStatusChange("An SSL certificate error occurred: "+err.Error(), false) } go func() { + InfoLogger.Printf("Attempting to listen on https://%s/", proxyURL) + isServerRunning = true httpServerError := http.ListenAndServeTLS(proxyURL, GetOrCreateDataPath()+"/cert.pem", GetOrCreateDataPath()+"/key.pem", nil) if httpServerError != nil { - onStatusChange("An error occurred.", false) + isServerRunning = false + errorMsg := fmt.Sprintf("HTTPS server failed to start: %v", httpServerError) + ErrorLogger.Println(errorMsg) + onStatusChange("An error occurred: "+httpServerError.Error(), false) } }() onStatusChange("Listening on https://"+proxyURL+"/", true) - log.Println("Proxy server listening on https://" + proxyURL + "/") + InfoLogger.Printf("Proxy server listening on https://%s/", proxyURL) + InfoLogger.Printf("Health check available at https://%s/health", proxyURL) } } @@ -138,6 +227,91 @@ func GetAccessToken() string { func SetAccessToken(newAccessToken string) { accessToken = newAccessToken + InfoLogger.Println("Access token updated") +} + +// GetHealthStatus returns the current health status of the proxy +func GetHealthStatus() HealthResponse { + var memory runtime.MemStats + runtime.ReadMemStats(&memory) + + return HealthResponse{ + Status: getStatusString(), + Uptime: time.Since(startTime).String(), + StartTime: startTime, + TotalRequests: atomic.LoadUint64(&totalRequests), + TotalErrors: atomic.LoadUint64(&totalErrors), + LastRequestTime: lastRequestTime, + Version: "1.1.0", // Update with your version + GoVersion: runtime.Version(), + NumGoroutine: runtime.NumGoroutine(), + MemoryAllocated: memory.Alloc, + MemoryTotal: memory.TotalAlloc, + MemorySystemUsed: memory.Sys, + } +} + +func getStatusString() string { + if isServerRunning { + return "healthy" + } + return "unhealthy" +} + +// healthCheckHandler provides a simple health check endpoint +func healthCheckHandler(w http.ResponseWriter, r *http.Request) { + // Record this request but don't count it in metrics + w.Header().Set("Content-Type", "application/json") + + // Get health status + healthStatus := GetHealthStatus() + + // Set appropriate status code + if healthStatus.Status == "healthy" { + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusServiceUnavailable) + } + + // Write response + err := json.NewEncoder(w).Encode(healthStatus) + if err != nil { + ErrorLogger.Printf("Failed to encode health check response: %v", err) + w.WriteHeader(http.StatusInternalServerError) + _, _ = fmt.Fprintln(w, "{\"status\":\"error\",\"message\":\"Failed to generate health response\"}") + } + + InfoLogger.Printf("Health check from %s returned status: %s", r.RemoteAddr, healthStatus.Status) +} + +// metricsHandler provides detailed metrics for monitoring +func metricsHandler(w http.ResponseWriter, r *http.Request) { + // Require access token for metrics endpoint + authHeader := r.Header.Get("Authorization") + + // Check if token is provided and valid + if len(accessToken) > 0 && (authHeader != "Bearer "+accessToken) { + ErrorLogger.Printf("Unauthorized metrics access from %s", r.RemoteAddr) + w.WriteHeader(http.StatusUnauthorized) + _, _ = fmt.Fprintln(w, "{\"status\":\"error\",\"message\":\"Unauthorized access\"}") + return + } + + // Get health status which includes all metrics + healthStatus := GetHealthStatus() + + // Return detailed metrics + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + err := json.NewEncoder(w).Encode(healthStatus) + if err != nil { + ErrorLogger.Printf("Failed to encode metrics response: %v", err) + w.WriteHeader(http.StatusInternalServerError) + _, _ = fmt.Fprintln(w, "{\"status\":\"error\",\"message\":\"Failed to generate metrics response\"}") + } + + InfoLogger.Printf("Metrics requested from %s", r.RemoteAddr) } const ErrorBodyInvalidRequest = "{\"success\": false, \"data\":{\"message\":\"(Proxy Error) Invalid request.\"}}" @@ -145,37 +319,68 @@ const ErrorBodyProxyRequestFailed = "{\"success\": false, \"data\":{\"message\": const maxMemory = int64(32 << 20) // multipartRequestDataKey currently its 32 MB func proxyHandler(response http.ResponseWriter, request *http.Request) { + // Update request metrics + atomic.AddUint64(&totalRequests, 1) + lastRequestTime = time.Now() + + startTime := time.Now() + clientIP := request.RemoteAddr + method := request.Method + requestURL := request.URL.String() + userAgent := request.Header.Get("User-Agent") + + // Log incoming request + InfoLogger.Printf("Received %s request from %s for %s (User-Agent: %s)", method, clientIP, requestURL, userAgent) + + // Skip processing for health check and metrics paths if they come through here + if request.URL.Path == "/health" || request.URL.Path == "/metrics" { + return + } + // We want to allow all types of requests to the proxy, though we only want to allow certain // origins. response.Header().Add("Access-Control-Allow-Headers", "*") if request.Method == "OPTIONS" { response.Header().Add("Access-Control-Allow-Origin", "*") response.WriteHeader(200) + DebugLogger.Printf("Responded to OPTIONS request from %s in %v", clientIP, time.Since(startTime)) return } - if request.Header.Get("Origin") == "" || !isAllowedOrigin(request.Header.Get("Origin")) { + origin := request.Header.Get("Origin") + if origin == "" || !isAllowedOrigin(origin) { + atomic.AddUint64(&totalErrors, 1) if strings.HasPrefix(request.Header.Get("Content-Type"), "application/json") { response.Header().Add("Access-Control-Allow-Headers", "*") response.Header().Add("Access-Control-Allow-Origin", "*") response.WriteHeader(200) - _, _ = fmt.Fprintln(response, ErrorBodyProxyRequestFailed) + _, err := fmt.Fprintln(response, ErrorBodyProxyRequestFailed) + if err != nil { + ErrorLogger.Printf("Failed to write error response: %v", err) + } + ErrorLogger.Printf("Denied access to %s from disallowed origin: %s", clientIP, origin) return } - // If it is not an allowed origin, redirect back to hoppscotch.io. response.Header().Add("Location", "https://hoppscotch.io/") response.WriteHeader(301) + InfoLogger.Printf("Redirected request from %s with disallowed origin: %s", clientIP, origin) return } else { // Otherwise set the appropriate CORS policy and continue. - response.Header().Add("Access-Control-Allow-Origin", request.Header.Get("Origin")) + response.Header().Add("Access-Control-Allow-Origin", origin) + DebugLogger.Printf("Allowed request from origin: %s", origin) } // For anything other than an POST request, we'll return an empty JSON object. response.Header().Add("Content-Type", "application/json; charset=utf-8") if request.Method != "POST" { - _, _ = fmt.Fprintln(response, "{\"success\": true, \"data\":{\"sessionFingerprint\":\""+sessionFingerprint+"\", \"isProtected\":"+strconv.FormatBool(len(accessToken) > 0)+"}}") + _, err := fmt.Fprintln(response, "{\"success\": true, \"data\":{\"sessionFingerprint\":\""+sessionFingerprint+"\", \"isProtected\":"+strconv.FormatBool(len(accessToken) > 0)+"}}") + if err != nil { + atomic.AddUint64(&totalErrors, 1) + ErrorLogger.Printf("Failed to write non-POST response: %v", err) + } + InfoLogger.Printf("Responded to %s request from %s in %v", method, clientIP, time.Since(startTime)) return } @@ -186,36 +391,75 @@ func proxyHandler(response http.ResponseWriter, request *http.Request) { if multipartRequestDataKey == "" { multipartRequestDataKey = "proxyRequestData" } + if isMultipart { - var err = request.ParseMultipartForm(maxMemory) + DebugLogger.Printf("Processing multipart form data request from %s", clientIP) + err := request.ParseMultipartForm(maxMemory) if err != nil { - log.Printf("Failed to parse request body: %v", err) - _, _ = fmt.Fprintln(response, ErrorBodyInvalidRequest) + atomic.AddUint64(&totalErrors, 1) + ErrorLogger.Printf("Failed to parse multipart form from %s: %v", clientIP, err) + _, writeErr := fmt.Fprintln(response, ErrorBodyInvalidRequest) + if writeErr != nil { + ErrorLogger.Printf("Failed to write error response: %v", writeErr) + } return } - r := request.MultipartForm.Value[multipartRequestDataKey] + + if request.MultipartForm == nil || request.MultipartForm.Value == nil { + atomic.AddUint64(&totalErrors, 1) + ErrorLogger.Printf("Invalid multipart form from %s: MultipartForm or Value is nil", clientIP) + _, writeErr := fmt.Fprintln(response, ErrorBodyInvalidRequest) + if writeErr != nil { + ErrorLogger.Printf("Failed to write error response: %v", writeErr) + } + return + } + + r, exists := request.MultipartForm.Value[multipartRequestDataKey] + if !exists || len(r) == 0 { + atomic.AddUint64(&totalErrors, 1) + ErrorLogger.Printf("Invalid multipart form from %s: missing %s key", clientIP, multipartRequestDataKey) + _, writeErr := fmt.Fprintln(response, ErrorBodyInvalidRequest) + if writeErr != nil { + ErrorLogger.Printf("Failed to write error response: %v", writeErr) + } + return + } + err = json.Unmarshal([]byte(r[0]), &requestData) if err != nil || len(requestData.Url) == 0 || len(requestData.Method) == 0 { - // If the logged err is nil here, it means either the URL or method were not supplied - // in the request data. - log.Printf("Failed to parse request body: %v", err) - _, _ = fmt.Fprintln(response, ErrorBodyInvalidRequest) + atomic.AddUint64(&totalErrors, 1) + ErrorLogger.Printf("Failed to parse request body from %s: %v", clientIP, err) + _, writeErr := fmt.Fprintln(response, ErrorBodyInvalidRequest) + if writeErr != nil { + ErrorLogger.Printf("Failed to write error response: %v", writeErr) + } return } } else { - var err = json.NewDecoder(request.Body).Decode(&requestData) + DebugLogger.Printf("Processing JSON request from %s", clientIP) + err := json.NewDecoder(request.Body).Decode(&requestData) if err != nil || len(requestData.Url) == 0 || len(requestData.Method) == 0 { - // If the logged err is nil here, it means either the URL or method were not supplied - // in the request data. - log.Printf("Failed to parse request body: %v", err) - _, _ = fmt.Fprintln(response, ErrorBodyInvalidRequest) + atomic.AddUint64(&totalErrors, 1) + ErrorLogger.Printf("Failed to parse JSON request body from %s: %v", clientIP, err) + _, writeErr := fmt.Fprintln(response, ErrorBodyInvalidRequest) + if writeErr != nil { + ErrorLogger.Printf("Failed to write error response: %v", writeErr) + } return } } + // Log the proxied request details + InfoLogger.Printf("Proxying %s request from %s to %s", requestData.Method, clientIP, requestData.Url) + if len(accessToken) > 0 && requestData.AccessToken != accessToken { - log.Print("An unauthorized request was made.") - _, _ = fmt.Fprintln(response, "{\"success\": false, \"data\":{\"message\":\"(Proxy Error) Unauthorized request; you may need to set your access token in Settings.\"}}") + atomic.AddUint64(&totalErrors, 1) + ErrorLogger.Printf("Unauthorized request from %s: Invalid access token", clientIP) + _, err := fmt.Fprintln(response, "{\"success\": false, \"data\":{\"message\":\"(Proxy Error) Unauthorized request; you may need to set your access token in Settings.\"}}") + if err != nil { + ErrorLogger.Printf("Failed to write unauthorized error: %v", err) + } return } @@ -223,12 +467,28 @@ func proxyHandler(response http.ResponseWriter, request *http.Request) { var proxyRequest http.Request proxyRequest.Header = make(http.Header) proxyRequest.Method = requestData.Method - proxyRequest.URL, _ = url.Parse(requestData.Url) + + // Parse URL and check for errors + parsedURL, err := url.Parse(requestData.Url) + if err != nil { + atomic.AddUint64(&totalErrors, 1) + ErrorLogger.Printf("Invalid URL from %s: %v", clientIP, err) + _, writeErr := fmt.Fprintln(response, "{\"success\": false, \"data\":{\"message\":\"(Proxy Error) Invalid URL: "+err.Error()+"\"}}") + if writeErr != nil { + ErrorLogger.Printf("Failed to write error response: %v", writeErr) + } + return + } + proxyRequest.URL = parsedURL // Block requests to illegal destinations if !isAllowedDest(proxyRequest.URL.Hostname()) { - log.Print("A request to a banned destination was made.") - _, _ = fmt.Fprintln(response, "{\"success\": false, \"data\":{\"message\":\"(Proxy Error) Request cannot be to this destination.\"}}") + atomic.AddUint64(&totalErrors, 1) + ErrorLogger.Printf("Request to banned destination %s from %s", proxyRequest.URL.Hostname(), clientIP) + _, err := fmt.Fprintln(response, "{\"success\": false, \"data\":{\"message\":\"(Proxy Error) Request cannot be to this destination.\"}}") + if err != nil { + ErrorLogger.Printf("Failed to write banned destination error: %v", err) + } return } @@ -241,7 +501,9 @@ func proxyHandler(response http.ResponseWriter, request *http.Request) { if len(requestData.Auth.Username) > 0 && len(requestData.Auth.Password) > 0 { proxyRequest.SetBasicAuth(requestData.Auth.Username, requestData.Auth.Password) + DebugLogger.Printf("Using basic auth for request to %s", proxyRequest.URL.String()) } + for k, v := range requestData.Headers { proxyRequest.Header.Set(k, v) } @@ -252,114 +514,186 @@ func proxyHandler(response http.ResponseWriter, request *http.Request) { if len(strings.TrimSpace(proxyRequest.Header.Get("User-Agent"))) < 1 { // If there is no valid user agent specified at all, *then* use the default. - // We'll do this for now, we could look at using the User-Agent from whatever made the request. proxyRequest.Header.Set("User-Agent", "Proxyscotch/1.1") } if isMultipart { body := &bytes.Buffer{} writer := multipart.NewWriter(body) + + // Process form fields for key := range request.MultipartForm.Value { if key == multipartRequestDataKey { continue } for _, val := range request.MultipartForm.Value[key] { - // This usually never happens, mostly memory issue err := writer.WriteField(key, val) if err != nil { - log.Printf("Failed to write multipart field key: %s error: %v", key, err) + atomic.AddUint64(&totalErrors, 1) + ErrorLogger.Printf("Failed to write multipart field key: %s error: %v", key, err) + _, writeErr := fmt.Fprintln(response, ErrorBodyProxyRequestFailed) + if writeErr != nil { + ErrorLogger.Printf("Failed to write error response: %v", writeErr) + } return } } } + + // Process files for fileKey := range request.MultipartForm.File { for _, val := range request.MultipartForm.File[fileKey] { f, err := val.Open() if err != nil { - log.Printf("Failed to write multipart field: %s err: %v", fileKey, err) + ErrorLogger.Printf("Failed to open file %s: %v", val.Filename, err) continue } - field, _ := writer.CreatePart(val.Header) - _, err = io.Copy(field, f) + + field, err := writer.CreatePart(val.Header) if err != nil { - log.Printf("Failed to write multipart field: %s err: %v", fileKey, err) - } - // Close need not be handled, as go will clear temp file - defer func(f multipart.File) { - err := f.Close() + ErrorLogger.Printf("Failed to create part for file %s: %v", val.Filename, err) + err = f.Close() if err != nil { - log.Printf("Failed to close file") + ErrorLogger.Printf("Failed to close file: %v", err) } - }(f) + continue + } + + _, err = io.Copy(field, f) + if err != nil { + ErrorLogger.Printf("Failed to copy file %s: %v", val.Filename, err) + } + + // Close file + err = f.Close() + if err != nil { + ErrorLogger.Printf("Failed to close file %s: %v", val.Filename, err) + } } } + err := writer.Close() if err != nil { - log.Printf("Failed to write multipart content: %v", err) - _, _ = fmt.Fprintf(response, ErrorBodyProxyRequestFailed) - if err != nil { - return + atomic.AddUint64(&totalErrors, 1) + ErrorLogger.Printf("Failed to finalize multipart content: %v", err) + _, writeErr := fmt.Fprintln(response, ErrorBodyProxyRequestFailed) + if writeErr != nil { + ErrorLogger.Printf("Failed to write error response: %v", writeErr) } return } + contentType := fmt.Sprintf("multipart/form-data; boundary=%v", writer.Boundary()) proxyRequest.Header.Set("content-type", contentType) proxyRequest.Body = io.NopCloser(bytes.NewReader(body.Bytes())) proxyRequest.ContentLength = int64(len(body.Bytes())) - _ = proxyRequest.Body.Close() } else if len(requestData.Data) > 0 { proxyRequest.Body = io.NopCloser(strings.NewReader(requestData.Data)) proxyRequest.ContentLength = int64(len(requestData.Data)) - _ = proxyRequest.Body.Close() } - var client http.Client - var proxyResponse *http.Response - proxyResponse, err := client.Do(&proxyRequest) + // Create client with timeout + var client = &http.Client{ + Timeout: 30 * time.Second, + } + DebugLogger.Printf("Sending proxied request to %s", proxyRequest.URL.String()) + proxyStartTime := time.Now() + + // Send request to target server + proxyResponse, err := client.Do(&proxyRequest) if err != nil { - log.Print("Failed to write response body: ", err.Error()) - _, _ = fmt.Fprintln(response, ErrorBodyProxyRequestFailed) + atomic.AddUint64(&totalErrors, 1) + ErrorLogger.Printf("Failed to execute proxied request to %s: %v", proxyRequest.URL.String(), err) + _, writeErr := fmt.Fprintln(response, "{\"success\": false, \"data\":{\"message\":\"(Proxy Error) Request failed: "+err.Error()+"\"}}") + if writeErr != nil { + ErrorLogger.Printf("Failed to write error response: %v", writeErr) + } return } + // Ensure body is closed after use + defer func(Body io.ReadCloser) { + err := Body.Close() + if err != nil { + ErrorLogger.Printf("Failed to close response body: %v", err) + } + }(proxyResponse.Body) + + InfoLogger.Printf("Received response from %s with status %d in %v", + proxyRequest.URL.String(), + proxyResponse.StatusCode, + time.Since(proxyStartTime)) + + // Build response data var responseData Response responseData.Success = true responseData.Status = proxyResponse.StatusCode responseData.StatusText = strings.Join(strings.Split(proxyResponse.Status, " ")[1:], " ") - responseBytes, _ := io.ReadAll(proxyResponse.Body) + + // Read response body + responseBytes, err := io.ReadAll(proxyResponse.Body) + if err != nil { + atomic.AddUint64(&totalErrors, 1) + ErrorLogger.Printf("Failed to read response body: %v", err) + _, writeErr := fmt.Fprintln(response, ErrorBodyProxyRequestFailed) + if writeErr != nil { + ErrorLogger.Printf("Failed to write error response: %v", writeErr) + } + return + } + responseData.Headers = headerToArray(proxyResponse.Header) if requestData.WantsBinary { + // Redact banned outputs for _, bannedOutput := range bannedOutputs { responseBytes = bytes.ReplaceAll(responseBytes, []byte(bannedOutput), []byte("[redacted]")) } - // If using the new binary format, encode the response body. + // Encode binary response responseData.Data = base64.RawStdEncoding.EncodeToString(responseBytes) responseData.IsBinary = true + DebugLogger.Printf("Returning binary response of %d bytes", len(responseBytes)) } else { - // Otherwise, simply return the old format. + // Return as string responseData.Data = string(responseBytes) + // Redact banned outputs for _, bannedOutput := range bannedOutputs { responseData.Data = strings.Replace(responseData.Data, bannedOutput, "[redacted]", -1) } + DebugLogger.Printf("Returning text response of %d bytes", len(responseData.Data)) } - // Write the request body to the response. + // Write the response err = json.NewEncoder(response).Encode(responseData) - - // Return the response. if err != nil { - log.Print("Failed to write response body: ", err.Error()) - _, _ = fmt.Fprintln(response, ErrorBodyProxyRequestFailed) + atomic.AddUint64(&totalErrors, 1) + ErrorLogger.Printf("Failed to encode response: %v", err) + _, writeErr := fmt.Fprintln(response, ErrorBodyProxyRequestFailed) + if writeErr != nil { + ErrorLogger.Printf("Failed to write error response: %v", writeErr) + } return } + + // Log completion + InfoLogger.Printf("Completed %s request from %s to %s in %v", + requestData.Method, + clientIP, + requestData.Url, + time.Since(startTime)) +} + +// EnableHealthCheck turns on or off the health check functionality +func EnableHealthCheck(enable bool) { + healthCheckEnabled = enable + InfoLogger.Printf("Health check endpoint %s", map[bool]string{true: "enabled", false: "disabled"}[enable]) } -// / Converts http.Header to a map. -// / Original Source: https://stackoverflow.com/a/37030039/2872279 (modified). +// Converts http.Header to a map. +// Original Source: https://stackoverflow.com/a/37030039/2872279 (modified). func headerToArray(header http.Header) (res map[string]string) { res = make(map[string]string) diff --git a/libproxy/proxy_test.go b/libproxy/proxy_test.go index 3214ffe..8d68ce1 100644 --- a/libproxy/proxy_test.go +++ b/libproxy/proxy_test.go @@ -5,12 +5,14 @@ import ( "encoding/base64" "encoding/json" "fmt" + "github.com/google/uuid" "github.com/mccutchen/go-httpbin/v2/httpbin" "github.com/stretchr/testify/assert" "net/http" "net/http/httptest" "net/url" "testing" + "time" ) type RespResult struct { @@ -45,11 +47,22 @@ func getResult(_req Request, origin string) RespResult { } func init() { + // Setup loggers first to avoid nil pointers + setupLoggers() + allowedOrigins = []string{"validorigin1.com", "validorigin2.com"} + bannedDests = []string{"banned.example.com"} app := httpbin.New() testServer := httptest.NewServer(app.Handler()) testServerUrl = testServer.URL + + // Initialize health metrics for testing + startTime = time.Now() + totalRequests = 0 + totalErrors = 0 + healthCheckEnabled = true + isServerRunning = true } func checkErrorNUnmarshalHTTPBinResponse(data string, t *testing.T) HTTPBinResponse { @@ -315,10 +328,50 @@ func TestInvalidAccessTokenRequestShouldFail(t *testing.T) { assert.Equal(t, "false", fmt.Sprintf("%v", success)) } -//func TestBannedOutputs(t *testing.T) { -// // TODO -// // need clear understanding on banned outputs -//} +func TestBannedOutputs(t *testing.T) { + // Test redaction of banned outputs + bannedOutputs = []string{"SECRET_TOKEN"} + defer func() { + bannedOutputs = []string{} // cleanup + }() + + request := Request{ + Method: "GET", + Url: testServerUrl + "/response-headers?response=Contains_SECRET_TOKEN_Value", + WantsBinary: false, + } + + resp := getResultDef(request) + assert.Equal(t, 200, resp.proxyResponse.Code) + // Check if the secret token is redacted + assert.NotContains(t, resp.requestResponse.Data, "SECRET_TOKEN") + assert.Contains(t, resp.requestResponse.Data, "[redacted]") +} + +func TestBannedDestination(t *testing.T) { + // Test blocking requests to banned destinations + request := Request{ + Method: "GET", + Url: "https://banned.example.com/resource", + } + + resp := getResultDef(request) + assert.Equal(t, 200, resp.proxyResponse.Code) + + var proxyRespParse map[string]interface{} + err := json.NewDecoder(resp.proxyResponse.Body).Decode(&proxyRespParse) + assert.Nil(t, err) + + success := proxyRespParse["success"] + assert.Equal(t, "false", fmt.Sprintf("%v", success)) + + // Check for the banned destination error message + data, ok := proxyRespParse["data"].(map[string]interface{}) + assert.True(t, ok) + message, ok := data["message"].(string) + assert.True(t, ok) + assert.Contains(t, message, "cannot be to this destination") +} func TestBasicAuth(t *testing.T) { request := Request{ @@ -335,12 +388,10 @@ func TestBasicAuth(t *testing.T) { resp := getResultDef(request) assert.Equal(t, 200, resp.requestResponse.Status) checkErrorNUnmarshalHTTPBinResponse(resp.requestResponse.Data, t) - } func TestBasicAuthIncorrectParams(t *testing.T) { // just to confirm above auth is working fine if username and password is sent wrong - request := Request{ Method: "GET", Url: testServerUrl + "/basic-auth/username/password2", @@ -351,3 +402,157 @@ func TestBasicAuthIncorrectParams(t *testing.T) { assert.Equal(t, 401, resp.requestResponse.Status) checkErrorNUnmarshalHTTPBinResponse(resp.requestResponse.Data, t) } + +func TestHealthCheckHandler(t *testing.T) { + // Test the health check endpoint + request := httptest.NewRequest("GET", "/health", nil) + recorder := httptest.NewRecorder() + + healthCheckHandler(recorder, request) + + result := recorder.Result() + defer result.Body.Close() + + assert.Equal(t, http.StatusOK, result.StatusCode) + + var healthResponse HealthResponse + err := json.NewDecoder(result.Body).Decode(&healthResponse) + assert.Nil(t, err) + + assert.Equal(t, "healthy", healthResponse.Status) + assert.NotEmpty(t, healthResponse.Uptime) + assert.Equal(t, startTime.Unix(), healthResponse.StartTime.Unix()) +} + +func TestMetricsHandler(t *testing.T) { + // Test metrics handler with no auth token + request := httptest.NewRequest("GET", "/metrics", nil) + recorder := httptest.NewRecorder() + + metricsHandler(recorder, request) + + result := recorder.Result() + defer result.Body.Close() + + assert.Equal(t, http.StatusOK, result.StatusCode) + + var metricsResponse HealthResponse + err := json.NewDecoder(result.Body).Decode(&metricsResponse) + assert.Nil(t, err) + + // Now test with auth token required + accessToken = "metrics-token" + defer func() { + accessToken = "" // cleanup + }() + + // Test without providing token (should fail) + recorder = httptest.NewRecorder() + metricsHandler(recorder, request) + assert.Equal(t, http.StatusUnauthorized, recorder.Result().StatusCode) + + // Test with correct token + request.Header.Set("Authorization", "Bearer metrics-token") + recorder = httptest.NewRecorder() + metricsHandler(recorder, request) + assert.Equal(t, http.StatusOK, recorder.Result().StatusCode) +} + +func TestSessionFingerprint(t *testing.T) { + // Test that a unique session fingerprint is generated + oldFingerprint := sessionFingerprint + + // Save original values to restore later + origInfoLogger := InfoLogger + origErrorLogger := ErrorLogger + origDebugLogger := DebugLogger + + // Ensure loggers are set up + if InfoLogger == nil || ErrorLogger == nil || DebugLogger == nil { + setupLoggers() + } + + // Instead of using Initialize which might cause issues in tests, + // we'll just generate a new UUID for the session fingerprint + sessionFingerprint = uuid.New().String() + + assert.NotEmpty(t, sessionFingerprint) + assert.NotEqual(t, oldFingerprint, sessionFingerprint) + + // Test that the fingerprint is returned in non-POST requests + request := httptest.NewRequest("GET", "/", nil) + request.Header.Set("Origin", "validorigin1.com") + recorder := httptest.NewRecorder() + + proxyHandler(recorder, request) + + var response map[string]interface{} + err := json.NewDecoder(recorder.Body).Decode(&response) + assert.Nil(t, err) + + data, ok := response["data"].(map[string]interface{}) + assert.True(t, ok) + returnedFingerprint, ok := data["sessionFingerprint"].(string) + assert.True(t, ok) + assert.Equal(t, sessionFingerprint, returnedFingerprint) + + // Restore original loggers if they were nil + if origInfoLogger == nil || origErrorLogger == nil || origDebugLogger == nil { + InfoLogger = origInfoLogger + ErrorLogger = origErrorLogger + DebugLogger = origDebugLogger + } +} + +func TestCustomUserAgent(t *testing.T) { + // Test that a custom User-Agent is passed through + customUA := "CustomUserAgent/1.0" + + resp := getResultDef(Request{ + Method: "GET", + Url: testServerUrl + "/get", + Headers: map[string]string{ + "User-Agent": customUA, + }, + }) + + httpBinResponse := checkErrorNUnmarshalHTTPBinResponse(resp.requestResponse.Data, t) + assert.Equal(t, customUA, httpBinResponse.Headers.Get("User-Agent")) + + // Test that default User-Agent is used when none is provided + resp = getResultDef(Request{ + Method: "GET", + Url: testServerUrl + "/get", + }) + + httpBinResponse = checkErrorNUnmarshalHTTPBinResponse(resp.requestResponse.Data, t) + assert.Equal(t, "Proxyscotch/1.1", httpBinResponse.Headers.Get("User-Agent")) +} + +func TestEnableHealthCheck(t *testing.T) { + // Test enabling and disabling the health check + oldValue := healthCheckEnabled + defer func() { + healthCheckEnabled = oldValue + }() + + EnableHealthCheck(false) + assert.False(t, healthCheckEnabled) + + EnableHealthCheck(true) + assert.True(t, healthCheckEnabled) +} + +func TestHeaderToArray(t *testing.T) { + // Test the headerToArray function + header := http.Header{} + header.Add("Content-Type", "application/json") + header.Add("X-Custom-Header", "value1") + header.Add("X-Custom-Header", "value2") // Multiple values for the same key + + result := headerToArray(header) + + // It should only keep the last value for each key + assert.Equal(t, "application/json", result["content-type"]) + assert.Equal(t, "value2", result["x-custom-header"]) +}