Skip to content

Commit cddec53

Browse files
authored
refactor(handler): 重构ModelKit为ModelKitHandler并整合usecase依赖 (#69)
将ModelKit结构体重命名为ModelKitHandler,并添加modelkit作为依赖项。 将usecase中的ModelList和CheckModel方法重构为ModelKit结构体的成员方法。 优化日志记录方式,统一使用slog.Logger。
1 parent 04c6639 commit cddec53

File tree

2 files changed

+49
-31
lines changed

2 files changed

+49
-31
lines changed

test/backend/main.go

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ package main
22

33
import (
44
"fmt"
5-
"net/http"
65
"log/slog"
6+
"net/http"
77
"os"
88

99
"github.com/labstack/echo/v4/middleware"
@@ -13,17 +13,20 @@ import (
1313
"github.com/labstack/echo/v4"
1414
)
1515

16-
type ModelKit struct{
17-
logger *slog.Logger
16+
type ModelKitHandler struct {
17+
logger *slog.Logger
18+
modelkit *usecase.ModelKit
1819
}
1920

2021
func NewModelKit(
2122
echo *echo.Echo,
2223
logger *slog.Logger,
2324
isApmEnabled bool,
24-
) *ModelKit {
25-
m := &ModelKit{
26-
logger: logger,
25+
modelkit *usecase.ModelKit,
26+
) *ModelKitHandler {
27+
m := &ModelKitHandler{
28+
logger: logger,
29+
modelkit: modelkit,
2730
}
2831

2932
// 注册路由
@@ -34,7 +37,7 @@ func NewModelKit(
3437
return m
3538
}
3639

37-
func (p *ModelKit) GetModelList(c echo.Context) error {
40+
func (p *ModelKitHandler) GetModelList(c echo.Context) error {
3841
var req domain.ModelListReq
3942
if err := c.Bind(&req); err != nil {
4043
return c.JSON(http.StatusOK, domain.Response{
@@ -44,8 +47,7 @@ func (p *ModelKit) GetModelList(c echo.Context) error {
4447
})
4548
}
4649

47-
48-
resp, err := usecase.ModelList(c.Request().Context(), &req, p.logger)
50+
resp, err := p.modelkit.ModelList(c.Request().Context(), &req)
4951
if err != nil {
5052
fmt.Println("err:", err)
5153
return c.JSON(http.StatusOK, domain.Response{
@@ -62,7 +64,7 @@ func (p *ModelKit) GetModelList(c echo.Context) error {
6264
})
6365
}
6466

65-
func (p *ModelKit) CheckModel(c echo.Context) error {
67+
func (p *ModelKitHandler) CheckModel(c echo.Context) error {
6668
var req domain.CheckModelReq
6769
if err := c.Bind(&req); err != nil {
6870
return c.JSON(http.StatusBadRequest, domain.Response{
@@ -71,7 +73,7 @@ func (p *ModelKit) CheckModel(c echo.Context) error {
7173
})
7274
}
7375

74-
resp, err := usecase.CheckModel(c.Request().Context(), &req, p.logger)
76+
resp, err := p.modelkit.CheckModel(c.Request().Context(), &req)
7577
if err != nil {
7678
fmt.Println("err:", err)
7779
return c.JSON(http.StatusOK, domain.Response{
@@ -115,7 +117,12 @@ func main() {
115117
// 添加CORS中间件
116118
echo.Use(middleware.CORS())
117119

118-
NewModelKit(echo, logger, false)
120+
// 创建ModelKit
121+
modelkit := usecase.NewModelKit(
122+
logger,
123+
)
124+
125+
NewModelKit(echo, logger, false, modelkit)
119126

120127
err := echo.Start(":8080")
121128
if err != nil {

usecase/modelkit.go

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ import (
66
"encoding/json"
77
"fmt"
88
"io"
9-
"log/slog"
109
"log"
10+
"log/slog"
1111
"maps"
1212
"net/http"
1313
"net/url"
@@ -34,9 +34,20 @@ import (
3434
"github.com/chaitin/ModelKit/v2/utils"
3535
)
3636

37-
func ModelList(ctx context.Context, req *domain.ModelListReq, logger *slog.Logger) (*domain.ModelListResp, error) {
38-
if logger != nil {
39-
logger.Info("ModelList req: provider=%s, baseURL=%s", req.Provider, req.BaseURL)
37+
type ModelKit struct {
38+
logger *slog.Logger
39+
}
40+
41+
// NewModelKit 创建一个新的ModelKit实例
42+
func NewModelKit(logger *slog.Logger) *ModelKit {
43+
return &ModelKit{
44+
logger: logger,
45+
}
46+
}
47+
48+
func (m *ModelKit) ModelList(ctx context.Context, req *domain.ModelListReq) (*domain.ModelListResp, error) {
49+
if m.logger != nil {
50+
m.logger.Info("ModelList req:", req.Provider, req.BaseURL)
4051
} else {
4152
log.Printf("ModelList req: provider=%s, baseURL=%s", req.Provider, req.BaseURL)
4253
}
@@ -69,8 +80,8 @@ func ModelList(ctx context.Context, req *domain.ModelListReq, logger *slog.Logge
6980
}
7081
defer func() {
7182
if closeErr := client.Close(); closeErr != nil {
72-
if logger != nil {
73-
logger.Error("Failed to close gemini client: %v", slog.Any("error", closeErr))
83+
if m.logger != nil {
84+
m.logger.Error("Failed to close gemini client: %v", slog.Any("error", closeErr))
7485
} else {
7586
log.Printf("Failed to close gemini client: %v", closeErr)
7687
}
@@ -177,9 +188,9 @@ func ModelList(ctx context.Context, req *domain.ModelListReq, logger *slog.Logge
177188
}
178189
}
179190

180-
func CheckModel(ctx context.Context, req *domain.CheckModelReq, logger *slog.Logger) (*domain.CheckModelResp, error) {
181-
if logger != nil {
182-
logger.Info("CheckModel req", "provider", req.Provider, "model", req.Model, "baseURL", req.BaseURL)
191+
func (m *ModelKit) CheckModel(ctx context.Context, req *domain.CheckModelReq) (*domain.CheckModelResp, error) {
192+
if m.logger != nil {
193+
m.logger.Info("CheckModel req", "provider", req.Provider, "model", req.Model, "baseURL", req.BaseURL)
183194
} else {
184195
log.Printf("CheckModel req: provider=%s, model=%s, baseURL=%s", req.Provider, req.Model, req.BaseURL)
185196
}
@@ -231,8 +242,8 @@ func CheckModel(ctx context.Context, req *domain.CheckModelReq, logger *slog.Log
231242
}
232243
defer func() {
233244
if closeErr := resp.Body.Close(); closeErr != nil {
234-
if logger != nil {
235-
logger.Error("Failed to close resp body: %v", slog.Any("error", closeErr))
245+
if m.logger != nil {
246+
m.logger.Error("Failed to close resp body: %v", slog.Any("error", closeErr))
236247
} else {
237248
log.Printf("Failed to close resp body: %v", closeErr)
238249
}
@@ -247,7 +258,7 @@ func CheckModel(ctx context.Context, req *domain.CheckModelReq, logger *slog.Log
247258
// end
248259
provider := consts.ParseModelProvider(req.Provider)
249260

250-
resp, err := getChatModelGenerateChat(ctx, provider, modelType, req.BaseURL, req, nil)
261+
resp, err := m.getChatModelGenerateChat(ctx, provider, modelType, req.BaseURL, req)
251262
// 可编辑url的供应商,尝试修复baseURL
252263
if err != nil && (provider == consts.ModelProviderOther || provider == consts.ModelProviderOllama || provider == consts.ModelProviderAzureOpenAI) {
253264
msg := generateBaseURLFixSuggestion(err.Error(), req.BaseURL, provider)
@@ -284,7 +295,7 @@ func CheckModel(ctx context.Context, req *domain.CheckModelReq, logger *slog.Log
284295
return checkResp, nil
285296
}
286297

287-
func GetChatModel(ctx context.Context, model *domain.ModelMetadata) (model.BaseChatModel, error) {
298+
func (m *ModelKit) GetChatModel(ctx context.Context, model *domain.ModelMetadata) (model.BaseChatModel, error) {
288299
// config chat model
289300
modelProvider := model.Provider
290301
var temperature float32 = 0.0
@@ -396,8 +407,8 @@ func ollamaListModel(baseURL string, httpClient *http.Client, apiHeader string)
396407
return request.Get[domain.ModelListResp](client, u.Path, request.WithHeader(h))
397408
}
398409

399-
func getChatModelGenerateChat(ctx context.Context, provider consts.ModelProvider, modelType consts.ModelType, baseURL string, req *domain.CheckModelReq, logger *log.Logger) (string, error) {
400-
chatModel, err := GetChatModel(ctx, &domain.ModelMetadata{
410+
func (m *ModelKit) getChatModelGenerateChat(ctx context.Context, provider consts.ModelProvider, modelType consts.ModelType, baseURL string, req *domain.CheckModelReq) (string, error) {
411+
chatModel, err := m.GetChatModel(ctx, &domain.ModelMetadata{
401412
Provider: provider,
402413
ModelName: req.Model,
403414
APIKey: req.APIKey,
@@ -416,15 +427,15 @@ func getChatModelGenerateChat(ctx context.Context, provider consts.ModelProvider
416427
})
417428
// 非流式生成失败,尝试流式生成
418429
if err != nil || genResp.Content == "" {
419-
if logger != nil {
420-
logger.Printf("Generate chat failed, err: %v", err)
430+
if m.logger != nil {
431+
m.logger.Info("Generate chat failed", slog.Any("error", err))
421432
} else {
422433
log.Printf("Generate chat failed, err: %v", err)
423434
}
424435
streamRes, streamErr := streamCheck(ctx, &chatModel)
425436
if streamErr != nil {
426-
if logger != nil {
427-
logger.Printf("Stream chat failed, err: %v", streamErr)
437+
if m.logger != nil {
438+
m.logger.Info("Stream chat failed", slog.Any("error", streamErr))
428439
} else {
429440
log.Printf("Stream chat failed, err: %v", streamErr)
430441
}

0 commit comments

Comments
 (0)