66 "encoding/json"
77 "fmt"
88 "io"
9- "log/slog"
109 "log"
10+ "log/slog"
1111 "maps"
1212 "net/http"
1313 "net/url"
@@ -34,9 +34,20 @@ import (
3434 "github.com/chaitin/ModelKit/v2/utils"
3535)
3636
37- func ModelList (ctx context.Context , req * domain.ModelListReq , logger * slog.Logger ) (* domain.ModelListResp , error ) {
38- if logger != nil {
39- logger .Info ("ModelList req: provider=%s, baseURL=%s" , req .Provider , req .BaseURL )
37+ type ModelKit struct {
38+ logger * slog.Logger
39+ }
40+
41+ // NewModelKit 创建一个新的ModelKit实例
42+ func NewModelKit (logger * slog.Logger ) * ModelKit {
43+ return & ModelKit {
44+ logger : logger ,
45+ }
46+ }
47+
48+ func (m * ModelKit ) ModelList (ctx context.Context , req * domain.ModelListReq ) (* domain.ModelListResp , error ) {
49+ if m .logger != nil {
50+ m .logger .Info ("ModelList req:" , req .Provider , req .BaseURL )
4051 } else {
4152 log .Printf ("ModelList req: provider=%s, baseURL=%s" , req .Provider , req .BaseURL )
4253 }
@@ -69,8 +80,8 @@ func ModelList(ctx context.Context, req *domain.ModelListReq, logger *slog.Logge
6980 }
7081 defer func () {
7182 if closeErr := client .Close (); closeErr != nil {
72- if logger != nil {
73- logger .Error ("Failed to close gemini client: %v" , slog .Any ("error" , closeErr ))
83+ if m . logger != nil {
84+ m . logger .Error ("Failed to close gemini client: %v" , slog .Any ("error" , closeErr ))
7485 } else {
7586 log .Printf ("Failed to close gemini client: %v" , closeErr )
7687 }
@@ -177,9 +188,9 @@ func ModelList(ctx context.Context, req *domain.ModelListReq, logger *slog.Logge
177188 }
178189}
179190
180- func CheckModel (ctx context.Context , req * domain.CheckModelReq , logger * slog. Logger ) (* domain.CheckModelResp , error ) {
181- if logger != nil {
182- logger .Info ("CheckModel req" , "provider" , req .Provider , "model" , req .Model , "baseURL" , req .BaseURL )
191+ func ( m * ModelKit ) CheckModel (ctx context.Context , req * domain.CheckModelReq ) (* domain.CheckModelResp , error ) {
192+ if m . logger != nil {
193+ m . logger .Info ("CheckModel req" , "provider" , req .Provider , "model" , req .Model , "baseURL" , req .BaseURL )
183194 } else {
184195 log .Printf ("CheckModel req: provider=%s, model=%s, baseURL=%s" , req .Provider , req .Model , req .BaseURL )
185196 }
@@ -231,8 +242,8 @@ func CheckModel(ctx context.Context, req *domain.CheckModelReq, logger *slog.Log
231242 }
232243 defer func () {
233244 if closeErr := resp .Body .Close (); closeErr != nil {
234- if logger != nil {
235- logger .Error ("Failed to close resp body: %v" , slog .Any ("error" , closeErr ))
245+ if m . logger != nil {
246+ m . logger .Error ("Failed to close resp body: %v" , slog .Any ("error" , closeErr ))
236247 } else {
237248 log .Printf ("Failed to close resp body: %v" , closeErr )
238249 }
@@ -247,7 +258,7 @@ func CheckModel(ctx context.Context, req *domain.CheckModelReq, logger *slog.Log
247258 // end
248259 provider := consts .ParseModelProvider (req .Provider )
249260
250- resp , err := getChatModelGenerateChat (ctx , provider , modelType , req .BaseURL , req , nil )
261+ resp , err := m . getChatModelGenerateChat (ctx , provider , modelType , req .BaseURL , req )
251262 // 可编辑url的供应商,尝试修复baseURL
252263 if err != nil && (provider == consts .ModelProviderOther || provider == consts .ModelProviderOllama || provider == consts .ModelProviderAzureOpenAI ) {
253264 msg := generateBaseURLFixSuggestion (err .Error (), req .BaseURL , provider )
@@ -284,7 +295,7 @@ func CheckModel(ctx context.Context, req *domain.CheckModelReq, logger *slog.Log
284295 return checkResp , nil
285296}
286297
287- func GetChatModel (ctx context.Context , model * domain.ModelMetadata ) (model.BaseChatModel , error ) {
298+ func ( m * ModelKit ) GetChatModel (ctx context.Context , model * domain.ModelMetadata ) (model.BaseChatModel , error ) {
288299 // config chat model
289300 modelProvider := model .Provider
290301 var temperature float32 = 0.0
@@ -396,8 +407,8 @@ func ollamaListModel(baseURL string, httpClient *http.Client, apiHeader string)
396407 return request .Get [domain.ModelListResp ](client , u .Path , request .WithHeader (h ))
397408}
398409
399- func getChatModelGenerateChat (ctx context.Context , provider consts.ModelProvider , modelType consts.ModelType , baseURL string , req * domain.CheckModelReq , logger * log. Logger ) (string , error ) {
400- chatModel , err := GetChatModel (ctx , & domain.ModelMetadata {
410+ func ( m * ModelKit ) getChatModelGenerateChat (ctx context.Context , provider consts.ModelProvider , modelType consts.ModelType , baseURL string , req * domain.CheckModelReq ) (string , error ) {
411+ chatModel , err := m . GetChatModel (ctx , & domain.ModelMetadata {
401412 Provider : provider ,
402413 ModelName : req .Model ,
403414 APIKey : req .APIKey ,
@@ -416,15 +427,15 @@ func getChatModelGenerateChat(ctx context.Context, provider consts.ModelProvider
416427 })
417428 // 非流式生成失败,尝试流式生成
418429 if err != nil || genResp .Content == "" {
419- if logger != nil {
420- logger .Printf ("Generate chat failed, err: %v" , err )
430+ if m . logger != nil {
431+ m . logger .Info ("Generate chat failed" , slog . Any ( "error" , err ) )
421432 } else {
422433 log .Printf ("Generate chat failed, err: %v" , err )
423434 }
424435 streamRes , streamErr := streamCheck (ctx , & chatModel )
425436 if streamErr != nil {
426- if logger != nil {
427- logger .Printf ("Stream chat failed, err: %v" , streamErr )
437+ if m . logger != nil {
438+ m . logger .Info ("Stream chat failed" , slog . Any ( "error" , streamErr ) )
428439 } else {
429440 log .Printf ("Stream chat failed, err: %v" , streamErr )
430441 }
0 commit comments