Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🚀 feat: add Load-Shedding Middleware for Request Timeout Management #3264

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
45 changes: 45 additions & 0 deletions middleware/loadshedding/loadshedding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package loadshedding

import (
"context"
"time"

"github.com/gofiber/fiber/v3"
)

// New creates a middleware handler enforces a timeout on request processing to manage server load.
// If a request exceeds the specified timeout, a custom load-shedding handler is executed.
func New(timeout time.Duration, loadSheddingHandler fiber.Handler, exclude func(fiber.Ctx) bool) fiber.Handler {
return func(c fiber.Ctx) error {
// Skip load-shedding logic for requests matching the exclusion criteria
if exclude != nil && exclude(c) {
return c.Next()
}

// Create a context with a timeout for the current request
ctx, cancel := context.WithTimeout(c.Context(), timeout)
defer cancel()

// Set the new context with a timeout
c.SetContext(ctx)

// Process the request and capture any error
err := c.Next()

// Create a channel to signal when request processing completes
done := make(chan error, 1)

// Send the result of the request processing to the channel
go func() {
done <- err
}()

// Handle either request completion or timeout
select {
case <-ctx.Done(): // Triggered if the timeout expires
return loadSheddingHandler(c)
case err := <-done: // Triggered if request processing completes
return err
}
}
}
92 changes: 92 additions & 0 deletions middleware/loadshedding/loadshedding_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package loadshedding_test

import (
"net/http/httptest"
"testing"
"time"

"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/loadshedding"
"github.com/stretchr/testify/require"
)

// Helper handlers
func successHandler(c fiber.Ctx) error {
return c.SendString("Request processed successfully!")
}

func timeoutHandler(c fiber.Ctx) error {
time.Sleep(2 * time.Second) // Simulate a long-running request
return c.SendString("This should not appear")
}

func loadSheddingHandler(c fiber.Ctx) error {
return c.Status(fiber.StatusServiceUnavailable).SendString("Service Overloaded")
}

func excludedHandler(c fiber.Ctx) error {
return c.SendString("Excluded route")
}

// go test -run Test_LoadSheddingExcluded
func Test_LoadSheddingExcluded(t *testing.T) {
t.Parallel()
app := fiber.New()

// Middleware with exclusion
app.Use(loadshedding.New(
1*time.Second,
loadSheddingHandler,
func(c fiber.Ctx) bool { return c.Path() == "/excluded" },
))
app.Get("/", successHandler)
app.Get("/excluded", excludedHandler)

// Test excluded route
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/excluded", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
}

// go test -run Test_LoadSheddingTimeout
func Test_LoadSheddingTimeout(t *testing.T) {
t.Parallel()
app := fiber.New()

// Middleware with a 1-second timeout
app.Use(loadshedding.New(
1*time.Second, // Middleware timeout
loadSheddingHandler,
nil,
))
app.Get("/", timeoutHandler)

// Create a custom request
req := httptest.NewRequest(fiber.MethodGet, "/", nil)

// Test timeout behavior
resp, err := app.Test(req, fiber.TestConfig{
Timeout: 3 * time.Second, // Ensure the test timeout exceeds middleware timeout
})
require.NoError(t, err)
require.Equal(t, fiber.StatusServiceUnavailable, resp.StatusCode)
}

// go test -run Test_LoadSheddingSuccessfulRequest
func Test_LoadSheddingSuccessfulRequest(t *testing.T) {
t.Parallel()
app := fiber.New()

// Middleware with sufficient time for request to complete
app.Use(loadshedding.New(
2*time.Second,
loadSheddingHandler,
nil,
))
app.Get("/", successHandler)

// Test successful request
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
}