Skip to content

Commit 1143dac

Browse files
Allow the sidecar to sample from a list of prefill host ports
In some benchmarking and test environments dynamic prefill selection may be difficult, and random selection among a set of hosts is sufficient. Add a new `--enable-prefiller-sampling` flag that instructs the sidecar to select a random prefill host from the provided list instead of the first one. Make the behavior opt-in to prevent users from accidentally depending on the new behavior, and keep the existing default behavior (first header value) consistent. For example, `curl -H 'x-prefiller-host-port: server1:8000' -H 'x-prefiller-host-port: server2:8000'` will randomly choose one of the two values. Signed-off-by: Clayton Coleman <[email protected]>
1 parent 3004256 commit 1143dac

File tree

3 files changed

+38
-9
lines changed

3 files changed

+38
-9
lines changed

cmd/llm-d-routing-sidecar/main.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"flag"
2121
"net/url"
2222
"os"
23+
"strconv"
2324

2425
"k8s.io/klog/v2"
2526

@@ -43,6 +44,7 @@ func main() {
4344
enableSSRFProtection := flag.Bool("enable-ssrf-protection", false, "enable SSRF protection using InferencePool allowlisting")
4445
inferencePoolNamespace := flag.String("inference-pool-namespace", os.Getenv("INFERENCE_POOL_NAMESPACE"), "the Kubernetes namespace to watch for InferencePool resources (defaults to INFERENCE_POOL_NAMESPACE env var)")
4546
inferencePoolName := flag.String("inference-pool-name", os.Getenv("INFERENCE_POOL_NAME"), "the specific InferencePool name to watch (defaults to INFERENCE_POOL_NAME env var)")
47+
enablePrefillerSampling := flag.Bool("enable-prefiller-sampling", func() bool { b, _ := strconv.ParseBool(os.Getenv("ENABLE_PREFILLER_SAMPLING")); return b }(), "if true, the target prefill instance will be selected randomly from among the provided prefill host values")
4648

4749
klog.InitFlags(nil)
4850
flag.Parse()
@@ -97,6 +99,7 @@ func main() {
9799
EnableSSRFProtection: *enableSSRFProtection,
98100
InferencePoolNamespace: *inferencePoolNamespace,
99101
InferencePoolName: *inferencePoolName,
102+
EnablePrefillerSampling: *enablePrefillerSampling,
100103
}
101104

102105
proxy, err := proxy.NewProxy(*port, targetURL, config)

internal/proxy/chat_completions.go

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ limitations under the License.
1717
package proxy
1818

1919
import (
20+
"math/rand"
2021
"net/http"
22+
"strings"
2123
)
2224

2325
var (
@@ -29,30 +31,50 @@ var (
2931
)
3032

3133
func (s *Server) chatCompletionsHandler(w http.ResponseWriter, r *http.Request) {
32-
prefillPodHostPort := r.Header.Get(requestHeaderPrefillHostPort)
34+
var prefillHostPorts []string
35+
prefillHostPorts = r.Header.Values(requestHeaderPrefillHostPort)
3336

34-
if prefillPodHostPort == "" {
37+
if len(prefillHostPorts) == 0 {
3538
// backward compatible behavior: to remove in next release
36-
prefillPodHostPort = r.Header.Get(requestHeaderPrefillURL)
39+
prefillHostPorts = r.Header.Values(requestHeaderPrefillURL)
3740
}
3841

39-
if prefillPodHostPort == "" {
42+
// https://datatracker.ietf.org/doc/html/rfc7230#section-3.2.2 specifies proxies
43+
// may combine multiple header values with a comma. Accept either one host per
44+
// header line OR one line with multiple header values.
45+
if len(prefillHostPorts) == 1 {
46+
prefillHostPorts = strings.Split(prefillHostPorts[0], ",")
47+
}
48+
49+
numHosts := len(prefillHostPorts)
50+
var prefillHostPort string
51+
if numHosts > 0 {
52+
if s.config.EnablePrefillerSampling {
53+
// Sample a host value from the list
54+
prefillHostPort = strings.TrimSpace(prefillHostPorts[rand.Intn(numHosts)])
55+
} else if numHosts > 0 {
56+
// Select only the first header value, consistent with previous behavior
57+
prefillHostPort = strings.TrimSpace(prefillHostPorts[0])
58+
}
59+
}
60+
61+
if len(prefillHostPort) == 0 {
4062
s.logger.V(4).Info("skip disaggregated prefill")
4163
s.decoderProxy.ServeHTTP(w, r)
4264
return
4365
}
4466

4567
// SSRF Protection: Check if the prefill target is allowed
46-
if !s.allowlistValidator.IsAllowed(prefillPodHostPort) {
68+
if !s.allowlistValidator.IsAllowed(prefillHostPort) {
4769
s.logger.Error(nil, "SSRF protection: prefill target not in allowlist",
48-
"target", prefillPodHostPort,
70+
"target", prefillHostPort,
4971
"clientIP", r.RemoteAddr,
5072
"userAgent", r.Header.Get("User-Agent"),
5173
"requestPath", r.URL.Path)
5274
http.Error(w, "Forbidden: prefill target not allowed by SSRF protection", http.StatusForbidden)
5375
return
5476
}
5577

56-
s.logger.V(4).Info("SSRF protection: prefill target allowed", "target", prefillPodHostPort)
57-
s.runConnectorProtocol(w, r, prefillPodHostPort)
78+
s.logger.V(4).Info("SSRF protection: prefill target allowed", "target", prefillHostPort)
79+
s.runConnectorProtocol(w, r, prefillHostPort)
5880
}

internal/proxy/proxy.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@ type Config struct {
8888

8989
// InferencePoolName InferencePool object name.
9090
InferencePoolName string
91+
92+
// EnablePrefillerSampling configures the proxy to randomly choose from the set
93+
// of provided prefill hosts instead of always using the first one.
94+
EnablePrefillerSampling bool
9195
}
9296

9397
type protocolRunner func(http.ResponseWriter, *http.Request, string)
@@ -265,7 +269,7 @@ func (s *Server) createRoutes() *http.ServeMux {
265269
// Log errors from the decoder proxy
266270
switch {
267271
case errors.Is(err, syscall.ECONNREFUSED):
268-
s.logger.Error(err, "waiting for vLLM to be ready")
272+
s.logger.Error(err, "waiting for model server to be ready")
269273
default:
270274
s.logger.Error(err, "http: proxy error")
271275
}

0 commit comments

Comments
 (0)