
Commit 2d280ba

feat(model config): add advanced parameters to ModelMetadata and support per-model configuration (#77)

Extend the ModelMetadata struct to support more advanced generation parameters, including temperature, top_p, and max_tokens. Implement the advanced-parameter configuration for OpenAI, DeepSeek, Gemini, and Ollama in the GetChatModel method.

1 parent 5e7e907 commit 2d280ba

3 files changed: +118 −11 lines changed

README.md

Lines changed: 2 additions & 2 deletions

@@ -39,9 +39,9 @@ go get github.com/go-playground/validator/v10

 #### Frontend dependencies
 ```bash
-npm install @yokowu/modelkit-ui
+npm install @ctzhian/modelkit
 #
-yarn add @yokowu/modelkit-ui
+yarn add @ctzhian/modelkit
 ```

 ### 2. Implement the interface

domain/model.go

Lines changed: 25 additions & 2 deletions

@@ -1,18 +1,41 @@
 package domain

-import "github.com/chaitin/ModelKit/v2/consts"
+import (
+    "github.com/chaitin/ModelKit/v2/consts"
+    "github.com/cloudwego/eino-ext/libs/acl/openai"
+)

 type ModelMetadata struct {
+    // Basic parameters
     ModelName string               `json:"id"`         // name of the model
     Object    string               `json:"object"`     // always "model"
     Created   int                  `json:"created"`    // creation time
     Provider  consts.ModelProvider `json:"provider"`   // provider
     ModelType consts.ModelType     `json:"model_type"` // model type
-
+    // API call parameters
     BaseURL    string `json:"base_url"`
     APIKey     string `json:"api_key"`
     APIHeader  string `json:"api_header"`
     APIVersion string `json:"api_version"` // for azure openai
+    // Advanced parameters
+    // Maximum number of tokens to generate; optional, defaults to the model's maximum. Not supported by Ollama.
+    MaxTokens *int `json:"max_tokens"`
+    // Sampling temperature, range 0-2; higher values make output more random. Use either this or TopP. Optional, default 1.0.
+    Temperature *float32 `json:"temperature"`
+    // Nucleus sampling, range 0-1; lower values make output more focused. Use either this or Temperature. Optional, default 1.0.
+    TopP *float32 `json:"top_p"`
+    // Sequences at which the API stops generating; optional, e.g. []string{"\n", "User:"}.
+    Stop []string `json:"stop"`
+    // Presence penalty, range -2 to 2; positive values raise the likelihood of new topics. Optional, default 0. Not supported by Gemini.
+    PresencePenalty *float32 `json:"presence_penalty"`
+    // Format of the model response, for structured output; optional. Not supported by DeepSeek, Gemini, or Ollama.
+    ResponseFormat *openai.ChatCompletionResponseFormat `json:"response_format"`
+    // Seed for deterministic sampling and reproducible output; optional. Not supported by DeepSeek or Gemini.
+    Seed *int `json:"seed"`
+    // Frequency penalty, range -2 to 2; positive values lower the likelihood of repetition. Optional, default 0. Not supported by Gemini.
+    FrequencyPenalty *float32 `json:"frequency_penalty"`
+    // Adjusts the likelihood of specific tokens appearing in the completion; a map of token IDs to bias values (-100 to 100). Optional. Not supported by DeepSeek, Gemini, or Ollama.
+    LogitBias map[string]int `json:"logit_bias"`
 }

 var Models []ModelMetadata
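
For orientation, here is a minimal sketch of how a caller might fill in the new fields. The concrete values and the generic ptr helper are illustrative assumptions, not part of this commit; only the struct fields and the DeepSeek provider constant appear in the diff.

```go
package main

import (
	"github.com/chaitin/ModelKit/v2/consts"
	"github.com/chaitin/ModelKit/v2/domain"
)

// ptr is a hypothetical helper that returns a pointer to v, so the
// optional fields can be set inline. A nil field still means
// "use the provider default".
func ptr[T any](v T) *T { return &v }

func main() {
	meta := domain.ModelMetadata{
		ModelName: "deepseek-chat",
		Provider:  consts.ModelProviderDeepSeek,
		BaseURL:   "https://api.deepseek.com", // placeholder endpoint
		APIKey:    "sk-...",                   // placeholder key
		// Advanced parameters: all optional.
		MaxTokens:   ptr(2048),
		Temperature: ptr(float32(0.7)), // set either Temperature or TopP, not both
		Stop:        []string{"\n\nUser:"},
	}
	_ = meta
}
```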

usecase/modelkit.go

Lines changed: 91 additions & 7 deletions

@@ -318,13 +318,46 @@ func (m *ModelKit) CheckModel(ctx context.Context, req *domain.CheckModelReq) (*
 func (m *ModelKit) GetChatModel(ctx context.Context, model *domain.ModelMetadata) (model.BaseChatModel, error) {
     // config chat model
     modelProvider := model.Provider
+
+    // Use the temperature from the advanced parameters, falling back to 0.0 when unset.
     var temperature float32 = 0.0
+    if model.Temperature != nil {
+        temperature = *model.Temperature
+    }
+
     config := &openai.ChatModelConfig{
         APIKey:      model.APIKey,
         BaseURL:     model.BaseURL,
         Model:       string(model.ModelName),
         Temperature: &temperature,
     }
+
+    // Apply the advanced parameters.
+    if model.MaxTokens != nil {
+        config.MaxTokens = model.MaxTokens
+    }
+    if model.TopP != nil {
+        config.TopP = model.TopP
+    }
+    if len(model.Stop) > 0 {
+        config.Stop = model.Stop
+    }
+    if model.PresencePenalty != nil {
+        config.PresencePenalty = model.PresencePenalty
+    }
+    if model.FrequencyPenalty != nil {
+        config.FrequencyPenalty = model.FrequencyPenalty
+    }
+    if model.ResponseFormat != nil {
+        config.ResponseFormat = model.ResponseFormat
+    }
+    if model.Seed != nil {
+        config.Seed = model.Seed
+    }
+    if model.LogitBias != nil {
+        config.LogitBias = model.LogitBias
+    }
+
     if modelProvider == consts.ModelProviderAzureOpenAI {
         config.ByAzure = true
         config.APIVersion = model.APIVersion
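
One detail worth noting in the hunk above: openai.ChatModelConfig keeps pointer fields, so the pointers pass straight through, while the DeepSeek and Ollama configs in the hunks below take plain values and must dereference after the nil check. A standalone sketch of the two styles, with hypothetical type names:

```go
package main

import "fmt"

// Hypothetical stand-ins for the two config styles seen in this diff.
type pointerStyle struct {
	TopP *float32 // nil means "let the provider decide"
}
type valueStyle struct {
	TopP float32 // 0 is a legal setting, so only assign when the user set one
}

func main() {
	user := float32(0.9) // pretend this came from ModelMetadata.TopP
	src := &user

	var p pointerStyle
	if src != nil {
		p.TopP = src // pointer field: pass the pointer through unchanged
	}

	var v valueStyle
	if src != nil {
		v.TopP = *src // value field: dereference, since "unset" is not representable
	}

	fmt.Println(*p.TopP, v.TopP) // 0.9 0.9
}
```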
@@ -341,12 +374,32 @@ func (m *ModelKit) GetChatModel(ctx context.Context, model *domain.ModelMetadata

     switch modelProvider {
     case consts.ModelProviderDeepSeek:
-        chatModel, err := deepseek.NewChatModel(ctx, &deepseek.ChatModelConfig{
+        deepseekConfig := &deepseek.ChatModelConfig{
             BaseURL:     model.BaseURL,
             APIKey:      model.APIKey,
             Model:       model.ModelName,
             Temperature: temperature,
-        })
+        }
+
+        // Apply the advanced parameters DeepSeek supports.
+        if model.MaxTokens != nil {
+            deepseekConfig.MaxTokens = *model.MaxTokens
+        }
+        if model.TopP != nil {
+            deepseekConfig.TopP = *model.TopP
+        }
+        if len(model.Stop) > 0 {
+            deepseekConfig.Stop = model.Stop
+        }
+        if model.PresencePenalty != nil {
+            deepseekConfig.PresencePenalty = *model.PresencePenalty
+        }
+        if model.FrequencyPenalty != nil {
+            deepseekConfig.FrequencyPenalty = *model.FrequencyPenalty
+        }
+        // ResponseFormat, Seed, and LogitBias are not supported by the DeepSeek config, so they are skipped.
+
+        chatModel, err := deepseek.NewChatModel(ctx, deepseekConfig)
         if err != nil {
             return nil, err
         }
@@ -359,14 +412,26 @@ func (m *ModelKit) GetChatModel(ctx context.Context, model *domain.ModelMetadata
             return nil, err
         }

-        chatModel, err := gemini.NewChatModel(ctx, &gemini.Config{
+        geminiConfig := &gemini.Config{
             Client: client,
             Model:  model.ModelName,
             ThinkingConfig: &genai.ThinkingConfig{
                 IncludeThoughts: true,
                 ThinkingBudget:  nil,
             },
-        })
+        }
+
+        // Apply the advanced parameters Gemini supports.
+        if model.MaxTokens != nil {
+            geminiConfig.MaxTokens = model.MaxTokens
+        }
+        if model.Temperature != nil {
+            geminiConfig.Temperature = model.Temperature
+        }
+        if model.TopP != nil {
+            geminiConfig.TopP = model.TopP
+        }
+        chatModel, err := gemini.NewChatModel(ctx, geminiConfig)
         if err != nil {
             return nil, err
         }
@@ -385,13 +450,32 @@ func (m *ModelKit) GetChatModel(ctx context.Context, model *domain.ModelMetadata
             return nil, err
         }

+        ollamaOptions := &api.Options{
+            Temperature: temperature,
+        }
+
+        // Apply the advanced parameters Ollama supports.
+        if model.TopP != nil {
+            ollamaOptions.TopP = *model.TopP
+        }
+        if len(model.Stop) > 0 {
+            ollamaOptions.Stop = model.Stop
+        }
+        if model.PresencePenalty != nil {
+            ollamaOptions.PresencePenalty = *model.PresencePenalty
+        }
+        if model.FrequencyPenalty != nil {
+            ollamaOptions.FrequencyPenalty = *model.FrequencyPenalty
+        }
+        if model.Seed != nil {
+            ollamaOptions.Seed = *model.Seed
+        }
+
         chatModel, err := ollama.NewChatModel(ctx, &ollama.ChatModelConfig{
             BaseURL: baseUrl,
             Timeout: config.Timeout,
             Model:   config.Model,
-            Options: &api.Options{
-                Temperature: temperature,
-            },
+            Options: ollamaOptions,
         })
         if err != nil {
             return nil, err
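
Putting it together, a caller can now thread advanced parameters through GetChatModel to any of the supported providers. A hedged end-to-end sketch follows; how a usable ModelKit value is constructed is not shown in this diff, and the eino Generate/UserMessage calls are assumptions about the surrounding framework:

```go
package main

import (
	"context"
	"fmt"

	"github.com/chaitin/ModelKit/v2/consts"
	"github.com/chaitin/ModelKit/v2/domain"
	"github.com/chaitin/ModelKit/v2/usecase"
	"github.com/cloudwego/eino/schema"
)

func main() {
	ctx := context.Background()
	temp := float32(0.2)

	meta := &domain.ModelMetadata{
		ModelName:   "deepseek-chat",
		Provider:    consts.ModelProviderDeepSeek,
		BaseURL:     "https://api.deepseek.com", // placeholder endpoint
		APIKey:      "sk-...",                   // placeholder key
		Temperature: &temp,                      // dereferenced into the DeepSeek config
		Stop:        []string{"\n\nUser:"},
	}

	// Assumption: obtain mk from the package's constructor, which is not part of this diff.
	var mk *usecase.ModelKit
	chatModel, err := mk.GetChatModel(ctx, meta)
	if err != nil {
		panic(err)
	}

	msg, err := chatModel.Generate(ctx, []*schema.Message{schema.UserMessage("Hello")})
	if err != nil {
		panic(err)
	}
	fmt.Println(msg.Content)
}
```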
