diff --git a/internal/proxy/connector_lmcache.go b/internal/proxy/connector_lmcache.go index ea4a864..ee66747 100644 --- a/internal/proxy/connector_lmcache.go +++ b/internal/proxy/connector_lmcache.go @@ -49,8 +49,8 @@ func (s *Server) runLMCacheProtocol(w http.ResponseWriter, r *http.Request, pref ctx := r.Context() preq := r.Clone(ctx) - completionRequest["max_tokens"] = 1 - completionRequest["max_completion_tokens"] = 1 + completionRequest[requestFieldMaxTokens] = 1 + completionRequest[requestFieldMaxCompletionTokens] = 1 pbody, err := json.Marshal(completionRequest) if err != nil { diff --git a/internal/proxy/connector_nixlv2.go b/internal/proxy/connector_nixlv2.go index 21d1b69..83b1a37 100644 --- a/internal/proxy/connector_nixlv2.go +++ b/internal/proxy/connector_nixlv2.go @@ -67,6 +67,7 @@ func (s *Server) runNIXLProtocolV2(w http.ResponseWriter, r *http.Request, prefi streamValue, streamOk := completionRequest[requestFieldStream] streamOptionsValue, streamOptionsOk := completionRequest[requestFieldStreamOptions] maxTokensValue, maxTokensOk := completionRequest[requestFieldMaxTokens] + maxCompletionTokensValue, maxCompletionTokensOk := completionRequest[requestFieldMaxCompletionTokens] completionRequest[requestFieldKVTransferParams] = map[string]any{ requestFieldDoRemoteDecode: true, @@ -80,6 +81,7 @@ func (s *Server) runNIXLProtocolV2(w http.ResponseWriter, r *http.Request, prefi completionRequest[requestFieldStream] = false delete(completionRequest, requestFieldStreamOptions) completionRequest[requestFieldMaxTokens] = 1 + completionRequest[requestFieldMaxCompletionTokens] = 1 pbody, err := json.Marshal(completionRequest) if err != nil { @@ -146,6 +148,10 @@ func (s *Server) runNIXLProtocolV2(w http.ResponseWriter, r *http.Request, prefi if maxTokensOk { completionRequest[requestFieldMaxTokens] = maxTokensValue } + delete(completionRequest, requestFieldMaxCompletionTokens) + if maxCompletionTokensOk { + completionRequest[requestFieldMaxCompletionTokens] = maxCompletionTokensValue + } completionRequest[requestFieldKVTransferParams] = pKVTransferParams dbody, err := json.Marshal(completionRequest) diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 68e6409..86f5041 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -39,16 +39,17 @@ const ( requestHeaderPrefillHostPort = "x-prefiller-host-port" requestHeaderRequestID = "x-request-id" - requestFieldKVTransferParams = "kv_transfer_params" - requestFieldMaxTokens = "max_tokens" - requestFieldDoRemotePrefill = "do_remote_prefill" - requestFieldDoRemoteDecode = "do_remote_decode" - requestFieldRemoteBlockIDs = "remote_block_ids" - requestFieldRemoteEngineID = "remote_engine_id" - requestFieldRemoteHost = "remote_host" - requestFieldRemotePort = "remote_port" - requestFieldStream = "stream" - requestFieldStreamOptions = "stream_options" + requestFieldKVTransferParams = "kv_transfer_params" + requestFieldMaxTokens = "max_tokens" + requestFieldMaxCompletionTokens = "max_completion_tokens" + requestFieldDoRemotePrefill = "do_remote_prefill" + requestFieldDoRemoteDecode = "do_remote_decode" + requestFieldRemoteBlockIDs = "remote_block_ids" + requestFieldRemoteEngineID = "remote_engine_id" + requestFieldRemoteHost = "remote_host" + requestFieldRemotePort = "remote_port" + requestFieldStream = "stream" + requestFieldStreamOptions = "stream_options" // ConnectorNIXLV1 enables the (now deprecated) P/D NIXL v1 protocol ConnectorNIXLV1 = "nixl"