Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -1041,8 +1041,13 @@ func (app *App) ErrorHandler(ctx *Ctx, err error) error {
mountedPrefixParts int
)

for prefix, subApp := range app.mountFields.appList {
if prefix != "" && strings.HasPrefix(ctx.path, prefix) {
normalizedPath := utils.AddTrailingSlash(ctx.Path())

for _, prefix := range app.mountFields.appListKeys {
subApp := app.mountFields.appList[prefix]
normalizedPrefix := utils.AddTrailingSlash(prefix)

if prefix != "" && strings.HasPrefix(normalizedPath, normalizedPrefix) {
parts := len(strings.Split(prefix, "/"))
if mountedPrefixParts <= parts {
if subApp.configured.ErrorHandler != nil {
Expand Down
67 changes: 67 additions & 0 deletions app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1977,3 +1977,70 @@ func Benchmark_Ctx_AcquireReleaseFlow(b *testing.B) {
}
})
}

func TestErrorHandler_PicksRightOne(t *testing.T) {
// common handler to be used by all routes,
// it will always fail by returning an error since
// we need to test that the right ErrorHandler is invoked
handler := func(c *Ctx) error {
return errors.New("random error")
}

// subapp /api/v1/users [no custom error handler]
appAPIV1Users := New()
appAPIV1Users.Get("/", handler)

// subapp /api/v1/use [with custom error handler]
appAPIV1UseEH := func(c *Ctx, _ error) error {
return c.SendString("/api/v1/use error handler")
}
appAPIV1Use := New(Config{ErrorHandler: appAPIV1UseEH})
appAPIV1Use.Get("/", handler)

// subapp: /api/v1 [with custom error handler]
appV1EH := func(c *Ctx, _ error) error {
return c.SendString("/api/v1 error handler")
}
appV1 := New(Config{ErrorHandler: appV1EH})
appV1.Get("/", handler)
appV1.Mount("/users", appAPIV1Users)
appV1.Mount("/use", appAPIV1Use)

// root app [no custom error handler]
app := New()
app.Get("/", handler)
app.Mount("/api/v1", appV1)

testCases := []struct {
path string // the endpoint url to test
expected string // the expected error response
}{
// /api/v1/users mount doesn't have custom ErrorHandler
// so it should use the upper-nearest one (/api/v1)
{"/api/v1/users", "/api/v1 error handler"},

// /api/v1/users mount has a custom ErrorHandler
{"/api/v1/use", "/api/v1/use error handler"},

// /api/v1 mount has a custom ErrorHandler
{"/api/v1", "/api/v1 error handler"},

// / mount doesn't have custom ErrorHandler, since is
// the root path i will use Fiber's default Error Handler
{"/", "random error"},
}

for _, testCase := range testCases {
resp, err := app.Test(httptest.NewRequest(MethodGet, testCase.path, nil))
if err != nil {
t.Fatal(err)
}

body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}

utils.AssertEqual(t, testCase.expected, string(body))
}
}
11 changes: 11 additions & 0 deletions utils/strings.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

package utils

import "strings"

// ToLower converts ascii string to lower-case
func ToLower(b string) string {
res := make([]byte, len(b))
Expand Down Expand Up @@ -73,3 +75,12 @@ func EqualFold(b, s string) bool {
}
return true
}

// AddTrailingSlash appends a trailing '/' to v if it does not already end with one
func AddTrailingSlash(s string) string {
if strings.HasSuffix(s, "/") {
return s
}

return s + "/"
}
63 changes: 63 additions & 0 deletions utils/strings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,3 +215,66 @@ func Test_EqualFold(t *testing.T) {
res = EqualFold("/MY4/NAME/IS/:PARAM/*", "/my4/nAME/IS/:param/*")
AssertEqual(t, true, res)
}

func Test_AddTrailingSlash(t *testing.T) {
t.Parallel()
tests := []struct {
name string
in string
want string
}{
{
name: "already has trailing slash",
in: "path/",
want: "path/",
},
{
name: "no trailing slash",
in: "path",
want: "path/",
},
{
name: "empty string",
in: "",
want: "/",
},
{
name: "root slash",
in: "/",
want: "/",
},
{
name: "multi-level path",
in: "a/b/c",
want: "a/b/c/",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := AddTrailingSlash(tt.in)
if got != tt.want {
t.Fatalf("AddTrailingSlash(%q) = %q, want %q", tt.in, got, tt.want)
}
})
}
}

func Benchmark_AddTrailingSlash(b *testing.B) {
cases := map[string]string{
"AlreadyHasSlash": "example/path/",
"NoSlash": "example/path",
"Empty": "",
"LongString": strings.Repeat("a", 10_000),
"LongStringWithSlash": strings.Repeat("a", 10_000) + "/",
}

for name, input := range cases {
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = AddTrailingSlash(input)
}
})
}
}