Skip to content

Commit e7151bd

Browse files
authored
Ensure max_completion_tokens=1 for prefill (#67)
Signed-off-by: Shmuel Kallner <[email protected]>
1 parent 6b9201b commit e7151bd

File tree

3 files changed

+19
-12
lines changed

3 files changed

+19
-12
lines changed

internal/proxy/connector_lmcache.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ func (s *Server) runLMCacheProtocol(w http.ResponseWriter, r *http.Request, pref
4949
ctx := r.Context()
5050
preq := r.Clone(ctx)
5151

52-
completionRequest["max_tokens"] = 1
53-
completionRequest["max_completion_tokens"] = 1
52+
completionRequest[requestFieldMaxTokens] = 1
53+
completionRequest[requestFieldMaxCompletionTokens] = 1
5454

5555
pbody, err := json.Marshal(completionRequest)
5656
if err != nil {

internal/proxy/connector_nixlv2.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ func (s *Server) runNIXLProtocolV2(w http.ResponseWriter, r *http.Request, prefi
6767
streamValue, streamOk := completionRequest[requestFieldStream]
6868
streamOptionsValue, streamOptionsOk := completionRequest[requestFieldStreamOptions]
6969
maxTokensValue, maxTokensOk := completionRequest[requestFieldMaxTokens]
70+
maxCompletionTokensValue, maxCompletionTokensOk := completionRequest[requestFieldMaxCompletionTokens]
7071

7172
completionRequest[requestFieldKVTransferParams] = map[string]any{
7273
requestFieldDoRemoteDecode: true,
@@ -80,6 +81,7 @@ func (s *Server) runNIXLProtocolV2(w http.ResponseWriter, r *http.Request, prefi
8081
completionRequest[requestFieldStream] = false
8182
delete(completionRequest, requestFieldStreamOptions)
8283
completionRequest[requestFieldMaxTokens] = 1
84+
completionRequest[requestFieldMaxCompletionTokens] = 1
8385

8486
pbody, err := json.Marshal(completionRequest)
8587
if err != nil {
@@ -146,6 +148,10 @@ func (s *Server) runNIXLProtocolV2(w http.ResponseWriter, r *http.Request, prefi
146148
if maxTokensOk {
147149
completionRequest[requestFieldMaxTokens] = maxTokensValue
148150
}
151+
delete(completionRequest, requestFieldMaxCompletionTokens)
152+
if maxCompletionTokensOk {
153+
completionRequest[requestFieldMaxCompletionTokens] = maxCompletionTokensValue
154+
}
149155
completionRequest[requestFieldKVTransferParams] = pKVTransferParams
150156

151157
dbody, err := json.Marshal(completionRequest)

internal/proxy/proxy.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,17 @@ const (
3939
requestHeaderPrefillHostPort = "x-prefiller-host-port"
4040
requestHeaderRequestID = "x-request-id"
4141

42-
requestFieldKVTransferParams = "kv_transfer_params"
43-
requestFieldMaxTokens = "max_tokens"
44-
requestFieldDoRemotePrefill = "do_remote_prefill"
45-
requestFieldDoRemoteDecode = "do_remote_decode"
46-
requestFieldRemoteBlockIDs = "remote_block_ids"
47-
requestFieldRemoteEngineID = "remote_engine_id"
48-
requestFieldRemoteHost = "remote_host"
49-
requestFieldRemotePort = "remote_port"
50-
requestFieldStream = "stream"
51-
requestFieldStreamOptions = "stream_options"
42+
requestFieldKVTransferParams = "kv_transfer_params"
43+
requestFieldMaxTokens = "max_tokens"
44+
requestFieldMaxCompletionTokens = "max_completion_tokens"
45+
requestFieldDoRemotePrefill = "do_remote_prefill"
46+
requestFieldDoRemoteDecode = "do_remote_decode"
47+
requestFieldRemoteBlockIDs = "remote_block_ids"
48+
requestFieldRemoteEngineID = "remote_engine_id"
49+
requestFieldRemoteHost = "remote_host"
50+
requestFieldRemotePort = "remote_port"
51+
requestFieldStream = "stream"
52+
requestFieldStreamOptions = "stream_options"
5253

5354
// ConnectorNIXLV1 enables the (now deprecated) P/D NIXL v1 protocol
5455
ConnectorNIXLV1 = "nixl"

0 commit comments

Comments
 (0)