Skip to content

Commit

Permalink
feat(ui): path prefix support via HTTP header
Browse files Browse the repository at this point in the history
Makes the web app honour the `X-Forwarded-Prefix` HTTP request header that may be sent by a reverse-proxy in order to inform the app that its public routes contain a path prefix.
For instance this allows to serve the webapp via a reverse-proxy/ingress controller under a path prefix/sub path such as e.g. `/localai/` while still being able to use the regular LocalAI routes/paths without prefix when directly connecting to the LocalAI server.

Changes:
* Add new `StripPathPrefix` middleware to strip the path prefix (provided with the `X-Forwarded-Prefix` HTTP request header) from the request path prior to matching the HTTP route.
* Add a `BaseURL` utility function to build the base URL, honouring the `X-Forwarded-Prefix` HTTP request header.
* Generate the derived base URL into the HTML (`head.html` template) as `<base/>` tag.
* Make all webapp-internal URLs (within HTML+JS) relative in order to make the browser resolve them against the `<base/>` URL specified within each HTML page's header.
* Make font URLs within the CSS files relative to the CSS file.
* Generate redirect location URLs using the new `BaseURL` function.
* Use the new `BaseURL` function to generate absolute URLs within gallery JSON responses.

Closes #3095

TL;DR:
The header-based approach allows to move the path prefix configuration concern completely to the reverse-proxy/ingress as opposed to having to align the path prefix configuration between LocalAI, the reverse-proxy and potentially other internal LocalAI clients.
The gofiber swagger handler already supports path prefixes this way, see https://github.com/gofiber/swagger/blob/e2d9e9916d8809e8b23c4365f8acfbbd8a71c4cd/swagger.go#L79

Signed-off-by: Max Goltzsche <[email protected]>
  • Loading branch information
mgoltzsche committed Dec 26, 2024
1 parent 5c29e0c commit 8cdff73
Show file tree
Hide file tree
Showing 35 changed files with 456 additions and 104 deletions.
2 changes: 2 additions & 0 deletions core/http/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ func API(application *application.Application) (*fiber.App, error) {

router := fiber.New(fiberCfg)

router.Use(middleware.StripPathPrefix())

router.Hooks().OnListen(func(listenData fiber.ListenData) error {
scheme := "http"
if listenData.TLS {
Expand Down
43 changes: 43 additions & 0 deletions core/http/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,35 @@ func postInvalidRequest(url string) (error, int) {
return nil, resp.StatusCode
}

func getRequest(url string, header http.Header) (error, int, []byte) {

req, err := http.NewRequest("GET", url, nil)
if err != nil {
return err, -1, nil
}

req.Header = header

client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err, -1, nil
}

defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
return err, -1, nil
}

if resp.StatusCode < 200 || resp.StatusCode >= 400 {
return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body)), resp.StatusCode, nil
}

return nil, resp.StatusCode, body
}

const bertEmbeddingsURL = `https://gist.githubusercontent.com/mudler/0a080b166b87640e8644b09c2aee6e3b/raw/f0e8c26bb72edc16d9fbafbfd6638072126ff225/bert-embeddings-gallery.yaml`

//go:embed backend-assets/*
Expand Down Expand Up @@ -345,6 +374,20 @@ var _ = Describe("API test", func() {
})
})

Context("URL routing Tests", func() {
It("Should support reverse-proxy", func() {

err, sc, body := getRequest("http://127.0.0.1:9090/myprefix/", http.Header{
"X-Forwarded-Proto": {"https"},
"X-Forwarded-Host": {"example.org"},
"X-Forwarded-Prefix": {"/myprefix/"},
})
Expect(err).To(BeNil(), "error")
Expect(sc).To(Equal(200), "status code")
Expect(string(body)).To(ContainSubstring(`<base href="https://example.org/myprefix/" />`), "body")
})
})

Context("Applying models", func() {

It("applies models from a gallery", func() {
Expand Down
6 changes: 3 additions & 3 deletions core/http/elements/buttons.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func installButton(galleryName string) elem.Node {
"class": "float-right inline-block rounded bg-primary px-6 pb-2.5 mb-3 pt-2.5 text-xs font-medium uppercase leading-normal text-white shadow-primary-3 transition duration-150 ease-in-out hover:bg-primary-accent-300 hover:shadow-primary-2 focus:bg-primary-accent-300 focus:shadow-primary-2 focus:outline-none focus:ring-0 active:bg-primary-600 active:shadow-primary-2 dark:shadow-black/30 dark:hover:shadow-dark-strong dark:focus:shadow-dark-strong dark:active:shadow-dark-strong",
"hx-swap": "outerHTML",
// post the Model ID as param
"hx-post": "/browse/install/model/" + galleryName,
"hx-post": "browse/install/model/" + galleryName,
},
elem.I(
attrs.Props{
Expand All @@ -36,7 +36,7 @@ func reInstallButton(galleryName string) elem.Node {
"hx-target": "#action-div-" + dropBadChars(galleryName),
"hx-swap": "outerHTML",
// post the Model ID as param
"hx-post": "/browse/install/model/" + galleryName,
"hx-post": "browse/install/model/" + galleryName,
},
elem.I(
attrs.Props{
Expand Down Expand Up @@ -80,7 +80,7 @@ func deleteButton(galleryID string) elem.Node {
"hx-target": "#action-div-" + dropBadChars(galleryID),
"hx-swap": "outerHTML",
// post the Model ID as param
"hx-post": "/browse/delete/model/" + galleryID,
"hx-post": "browse/delete/model/" + galleryID,
},
elem.I(
attrs.Props{
Expand Down
2 changes: 1 addition & 1 deletion core/http/elements/gallery.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func searchableElement(text, icon string) elem.Node {
// "value": text,
//"class": "inline-block bg-gray-200 rounded-full px-3 py-1 text-sm font-semibold text-gray-700 mr-2 mb-2",
"href": "#!",
"hx-post": "/browse/search/models",
"hx-post": "browse/search/models",
"hx-target": "#search-results",
// TODO: this doesn't work
// "hx-vals": `{ \"search\": \"` + text + `\" }`,
Expand Down
4 changes: 2 additions & 2 deletions core/http/elements/progressbar.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func StartProgressBar(uid, progress, text string) string {
return elem.Div(
attrs.Props{
"hx-trigger": "done",
"hx-get": "/browse/job/" + uid,
"hx-get": "browse/job/" + uid,
"hx-swap": "outerHTML",
"hx-target": "this",
},
Expand All @@ -77,7 +77,7 @@ func StartProgressBar(uid, progress, text string) string {
},
elem.Text(bluemonday.StrictPolicy().Sanitize(text)), //Perhaps overly defensive
elem.Div(attrs.Props{
"hx-get": "/browse/job/progress/" + uid,
"hx-get": "browse/job/progress/" + uid,
"hx-trigger": "every 600ms",
"hx-target": "this",
"hx-swap": "innerHTML",
Expand Down
2 changes: 2 additions & 0 deletions core/http/endpoints/explorer/dashboard.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/explorer"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/internal"
)

Expand All @@ -14,6 +15,7 @@ func Dashboard() func(*fiber.Ctx) error {
summary := fiber.Map{
"Title": "LocalAI API - " + internal.PrintableVersion(),
"Version": internal.PrintableVersion(),
"BaseURL": utils.BaseURL(c),
}

if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
Expand Down
6 changes: 4 additions & 2 deletions core/http/endpoints/localai/gallery.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/rs/zerolog/log"
Expand Down Expand Up @@ -82,7 +83,8 @@ func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fibe
Galleries: mgs.galleries,
ConfigURL: input.ConfigURL,
}
return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})

return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), uuid.String())})
}
}

Expand All @@ -105,7 +107,7 @@ func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() func(c *fib
return err
}

return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), uuid.String())})
}
}

Expand Down
2 changes: 2 additions & 0 deletions core/http/endpoints/localai/welcome.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/p2p"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/internal"
Expand Down Expand Up @@ -32,6 +33,7 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig,
summary := fiber.Map{
"Title": "LocalAI API - " + internal.PrintableVersion(),
"Version": internal.PrintableVersion(),
"BaseURL": utils.BaseURL(c),
"Models": modelsWithoutConfig,
"ModelsConfig": backendConfigs,
"GalleryConfig": galleryConfigs,
Expand Down
2 changes: 2 additions & 0 deletions core/http/explorer.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/gofiber/fiber/v2/middleware/favicon"
"github.com/gofiber/fiber/v2/middleware/filesystem"
"github.com/mudler/LocalAI/core/explorer"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/http/routes"
)

Expand All @@ -22,6 +23,7 @@ func Explorer(db *explorer.Database) *fiber.App {

app := fiber.New(fiberCfg)

app.Use(middleware.StripPathPrefix())
routes.RegisterExplorerRoutes(app, db)

httpFS := http.FS(embedDirStatic)
Expand Down
36 changes: 36 additions & 0 deletions core/http/middleware/strippathprefix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package middleware

import (
"strings"

"github.com/gofiber/fiber/v2"
)

// StripPathPrefix returns a middleware that strips a path prefix from the request path.
// The path prefix is obtained from the X-Forwarded-Prefix HTTP request header.
func StripPathPrefix() fiber.Handler {
return func(c *fiber.Ctx) error {
for _, prefix := range c.GetReqHeaders()["X-Forwarded-Prefix"] {
if prefix != "" {
path := c.Path()
pos := len(prefix)

if prefix[pos-1] == '/' {
pos--
} else {
prefix += "/"
}

if strings.HasPrefix(path, prefix) {
c.Path(path[pos:])
break
} else if prefix[:pos] == path {
c.Redirect(prefix)
return nil
}
}
}

return c.Next()
}
}
121 changes: 121 additions & 0 deletions core/http/middleware/strippathprefix_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package middleware

import (
"net/http/httptest"
"testing"

"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/require"
)

func TestStripPathPrefix(t *testing.T) {
var actualPath string

app := fiber.New()

app.Use(StripPathPrefix())

app.Get("/hello/world", func(c *fiber.Ctx) error {
actualPath = c.Path()
return nil
})

app.Get("/", func(c *fiber.Ctx) error {
actualPath = c.Path()
return nil
})

for _, tc := range []struct {
name string
path string
prefixHeader []string
expectStatus int
expectPath string
}{
{
name: "without prefix and header",
path: "/hello/world",
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "without prefix and headers on root path",
path: "/",
expectStatus: 200,
expectPath: "/",
},
{
name: "without prefix but header",
path: "/hello/world",
prefixHeader: []string{"/otherprefix/"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix but non-matching header",
path: "/prefix/hello/world",
prefixHeader: []string{"/otherprefix/"},
expectStatus: 404,
},
{
name: "with prefix and matching header",
path: "/myprefix/hello/world",
prefixHeader: []string{"/myprefix/"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix and 1st header matching",
path: "/myprefix/hello/world",
prefixHeader: []string{"/myprefix/", "/otherprefix/"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix and 2nd header matching",
path: "/myprefix/hello/world",
prefixHeader: []string{"/otherprefix/", "/myprefix/"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix and header not ending with slash",
path: "/myprefix/hello/world",
prefixHeader: []string{"/myprefix"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix and non-matching header not ending with slash",
path: "/myprefix-suffix/hello/world",
prefixHeader: []string{"/myprefix"},
expectStatus: 404,
},
{
name: "redirect when prefix does not end with a slash",
path: "/myprefix",
prefixHeader: []string{"/myprefix"},
expectStatus: 302,
expectPath: "/myprefix/",
},
} {
t.Run(tc.name, func(t *testing.T) {
actualPath = ""
req := httptest.NewRequest("GET", tc.path, nil)
if tc.prefixHeader != nil {
req.Header["X-Forwarded-Prefix"] = tc.prefixHeader
}

resp, err := app.Test(req, -1)

require.NoError(t, err)
require.Equal(t, tc.expectStatus, resp.StatusCode, "response status code")

if tc.expectStatus == 200 {
require.Equal(t, tc.expectPath, actualPath, "rewritten path")
} else if tc.expectStatus == 302 {
require.Equal(t, tc.expectPath, resp.Header.Get("Location"), "redirect location")
}
})
}
}
5 changes: 4 additions & 1 deletion core/http/render.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/gofiber/fiber/v2"
fiberhtml "github.com/gofiber/template/html/v2"
"github.com/microcosm-cc/bluemonday"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/schema"
"github.com/russross/blackfriday"
)
Expand All @@ -26,7 +27,9 @@ func notFoundHandler(c *fiber.Ctx) error {
})
} else {
// The client expects an HTML response
return c.Status(fiber.StatusNotFound).Render("views/404", fiber.Map{})
return c.Status(fiber.StatusNotFound).Render("views/404", fiber.Map{
"BaseURL": utils.BaseURL(c),
})
}
}

Expand Down
Loading

0 comments on commit 8cdff73

Please sign in to comment.