Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cmd/llm-d-routing-sidecar/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"flag"
"net/url"
"os"
"strconv"

"k8s.io/klog/v2"

Expand All @@ -43,6 +44,7 @@ func main() {
enableSSRFProtection := flag.Bool("enable-ssrf-protection", false, "enable SSRF protection using InferencePool allowlisting")
inferencePoolNamespace := flag.String("inference-pool-namespace", os.Getenv("INFERENCE_POOL_NAMESPACE"), "the Kubernetes namespace to watch for InferencePool resources (defaults to INFERENCE_POOL_NAMESPACE env var)")
inferencePoolName := flag.String("inference-pool-name", os.Getenv("INFERENCE_POOL_NAME"), "the specific InferencePool name to watch (defaults to INFERENCE_POOL_NAME env var)")
enablePrefillerSampling := flag.Bool("enable-prefiller-sampling", func() bool { b, _ := strconv.ParseBool(os.Getenv("ENABLE_PREFILLER_SAMPLING")); return b }(), "if true, the target prefill instance will be selected randomly from among the provided prefill host values")

klog.InitFlags(nil)
flag.Parse()
Expand Down Expand Up @@ -97,6 +99,7 @@ func main() {
EnableSSRFProtection: *enableSSRFProtection,
InferencePoolNamespace: *inferencePoolNamespace,
InferencePoolName: *inferencePoolName,
EnablePrefillerSampling: *enablePrefillerSampling,
}

proxy, err := proxy.NewProxy(*port, targetURL, config)
Expand Down
38 changes: 30 additions & 8 deletions internal/proxy/chat_completions.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ limitations under the License.
package proxy

import (
"math/rand"
"net/http"
"strings"
)

var (
Expand All @@ -29,30 +31,50 @@ var (
)

func (s *Server) chatCompletionsHandler(w http.ResponseWriter, r *http.Request) {
prefillPodHostPort := r.Header.Get(requestHeaderPrefillHostPort)
var prefillHostPorts []string
prefillHostPorts = r.Header.Values(requestHeaderPrefillHostPort)

if prefillPodHostPort == "" {
if len(prefillHostPorts) == 0 {
// backward compatible behavior: to remove in next release
prefillPodHostPort = r.Header.Get(requestHeaderPrefillURL)
prefillHostPorts = r.Header.Values(requestHeaderPrefillURL)
}

if prefillPodHostPort == "" {
// https://datatracker.ietf.org/doc/html/rfc7230#section-3.2.2 specifies proxies
// may combine multiple header values with a comma. Accept either one host per
// header line OR one line with multiple header values.
if len(prefillHostPorts) == 1 {
prefillHostPorts = strings.Split(prefillHostPorts[0], ",")
}

numHosts := len(prefillHostPorts)
var prefillHostPort string
if numHosts > 0 {
if s.config.EnablePrefillerSampling {
// Sample a host value from the list
prefillHostPort = strings.TrimSpace(prefillHostPorts[rand.Intn(numHosts)])
} else if numHosts > 0 {
// Select only the first header value, consistent with previous behavior
prefillHostPort = strings.TrimSpace(prefillHostPorts[0])
}
}

if len(prefillHostPort) == 0 {
s.logger.V(4).Info("skip disaggregated prefill")
s.decoderProxy.ServeHTTP(w, r)
return
}

// SSRF Protection: Check if the prefill target is allowed
if !s.allowlistValidator.IsAllowed(prefillPodHostPort) {
if !s.allowlistValidator.IsAllowed(prefillHostPort) {
s.logger.Error(nil, "SSRF protection: prefill target not in allowlist",
"target", prefillPodHostPort,
"target", prefillHostPort,
"clientIP", r.RemoteAddr,
"userAgent", r.Header.Get("User-Agent"),
"requestPath", r.URL.Path)
http.Error(w, "Forbidden: prefill target not allowed by SSRF protection", http.StatusForbidden)
return
}

s.logger.V(4).Info("SSRF protection: prefill target allowed", "target", prefillPodHostPort)
s.runConnectorProtocol(w, r, prefillPodHostPort)
s.logger.V(4).Info("SSRF protection: prefill target allowed", "target", prefillHostPort)
s.runConnectorProtocol(w, r, prefillHostPort)
}
89 changes: 89 additions & 0 deletions internal/proxy/chat_completions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
Copyright 2025 The llm-d Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package proxy

import (
"net/http"
"net/http/httptest"
"slices"
"testing"
)

// mockConnectorProtocol is an empty placeholder type for a connector-protocol
// test double. It declares no fields or methods and is not referenced by the
// tests in this file (the linter reports it as unused).
type mockConnectorProtocol struct{}

func TestServer_chatCompletionsHandler(t *testing.T) {
tests := []struct {
name string
port string
sampling bool
r *http.Request

expectedCode int
expectedPrefillerIn []string
expectedPassthrough bool
}{
{r: &http.Request{}, expectedPassthrough: true},
{r: &http.Request{Header: http.Header{http.CanonicalHeaderKey(requestHeaderPrefillHostPort): []string{}}}, expectedPassthrough: true},
{r: &http.Request{Header: http.Header{http.CanonicalHeaderKey(requestHeaderPrefillHostPort): []string{"a"}}}, expectedPrefillerIn: []string{"a"}},
{r: &http.Request{Header: http.Header{http.CanonicalHeaderKey(requestHeaderPrefillHostPort): []string{"a,b"}}}, expectedPrefillerIn: []string{"a"}},
{r: &http.Request{Header: http.Header{http.CanonicalHeaderKey(requestHeaderPrefillHostPort): []string{"a,b"}}}, sampling: true, expectedPrefillerIn: []string{"a", "b"}},
{r: &http.Request{Header: http.Header{http.CanonicalHeaderKey(requestHeaderPrefillHostPort): []string{" a, b"}}}, sampling: true, expectedPrefillerIn: []string{"a", "b"}},
{r: &http.Request{Header: http.Header{http.CanonicalHeaderKey(requestHeaderPrefillHostPort): []string{"a,a"}}}, sampling: true, expectedPrefillerIn: []string{"a"}},
{r: &http.Request{Header: http.Header{http.CanonicalHeaderKey(requestHeaderPrefillHostPort): []string{"a", "b"}}}, sampling: true, expectedPrefillerIn: []string{"a", "b"}},
{r: &http.Request{Header: http.Header{http.CanonicalHeaderKey(requestHeaderPrefillHostPort): []string{""}}}, sampling: true, expectedPassthrough: true},
{r: &http.Request{Header: http.Header{http.CanonicalHeaderKey(requestHeaderPrefillHostPort): []string{"", ""}}}, sampling: true, expectedPassthrough: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

s, err := NewProxy(tt.port, nil, Config{EnablePrefillerSampling: tt.sampling})
if err != nil {
t.Fatalf("could not construct receiver type: %v", err)
}
for i := 0; i < max(1, len(tt.expectedPrefillerIn)*3); i++ {
var prefiller string
s.runConnectorProtocol = func(w http.ResponseWriter, req *http.Request, hostPort string) { prefiller = hostPort }

Check failure on line 60 in internal/proxy/chat_completions_test.go

View workflow job for this annotation

GitHub Actions / lint-and-test

unused-parameter: parameter 'w' seems to be unused, consider removing or renaming it as _ (revive)
var passthrough bool
s.decoderProxy = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {

Check failure on line 62 in internal/proxy/chat_completions_test.go

View workflow job for this annotation

GitHub Actions / lint-and-test

unused-parameter: parameter 'w' seems to be unused, consider removing or renaming it as _ (revive)
passthrough = true
})
recorder := httptest.NewRecorder()
recorder.Code = 0
s.chatCompletionsHandler(recorder, tt.r)
if passthrough {
if !tt.expectedPassthrough {
t.Errorf("unexpected passthrough to decode")
}
if recorder.Body.Len() > 0 || recorder.Code != 0 || len(recorder.Header()) > 0 {
t.Errorf("unexpected write to response: %#v", recorder)
}
} else {
if tt.expectedPassthrough {
t.Fatal("unexpected handled request")
}
if recorder.Code != tt.expectedCode {
t.Errorf("unexpected code: %d", recorder.Code)
}
if !slices.Contains(tt.expectedPrefillerIn, prefiller) {
t.Errorf("unexpected prefiller %s", prefiller)
}
}
}
})
}
}
6 changes: 5 additions & 1 deletion internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ type Config struct {

// InferencePoolName InferencePool object name.
InferencePoolName string

// EnablePrefillerSampling configures the proxy to randomly choose from the set
// of provided prefill hosts instead of always using the first one.
EnablePrefillerSampling bool
}

// protocolRunner is the function signature for a handler that receives the
// response writer, the incoming request, and one additional string argument.
// NOTE(review): the string is presumably the selected prefill host:port, as
// passed by chatCompletionsHandler — confirm against the call sites.
type protocolRunner func(http.ResponseWriter, *http.Request, string)
Expand Down Expand Up @@ -265,7 +269,7 @@ func (s *Server) createRoutes() *http.ServeMux {
// Log errors from the decoder proxy
switch {
case errors.Is(err, syscall.ECONNREFUSED):
s.logger.Error(err, "waiting for vLLM to be ready")
s.logger.Error(err, "waiting for model server to be ready")
default:
s.logger.Error(err, "http: proxy error")
}
Expand Down
Loading