diff --git a/backend/internal/handler/admin/account_codex_import.go b/backend/internal/handler/admin/account_codex_import.go index 0c599522b39..62a0ef39621 100644 --- a/backend/internal/handler/admin/account_codex_import.go +++ b/backend/internal/handler/admin/account_codex_import.go @@ -153,7 +153,7 @@ func (h *AccountHandler) importCodexSessions(ctx context.Context, req CodexSessi Items: make([]CodexSessionImportItem, 0, len(entries)), } - existingAccounts, err := h.listAccountsFiltered(ctx, service.PlatformOpenAI, service.AccountTypeOAuth, "", "", 0, "", "created_at", "desc") + existingAccounts, err := h.listAccountsFiltered(ctx, service.PlatformOpenAI, service.AccountTypeOAuth, "", "", 0, "", "", "", "", "created_at", "desc") if err != nil { return result, err } diff --git a/backend/internal/handler/admin/account_data.go b/backend/internal/handler/admin/account_data.go index 00da48212aa..077463743ed 100644 --- a/backend/internal/handler/admin/account_data.go +++ b/backend/internal/handler/admin/account_data.go @@ -373,12 +373,12 @@ func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, e return out, nil } -func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, accountType, status, search string, groupID int64, privacyMode, sortBy, sortOrder string) ([]service.Account, error) { +func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, accountType, status, search string, groupID int64, model, quotaStrategy, proxyFilter, privacyMode, sortBy, sortOrder string) ([]service.Account, error) { page := 1 pageSize := dataPageCap var out []service.Account for { - items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, groupID, privacyMode, sortBy, sortOrder) + items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, groupID, model, quotaStrategy, proxyFilter, privacyMode, sortBy, sortOrder) if err != nil { return nil, err } @@ -410,6 +410,9 @@ func (h *AccountHandler) resolveExportAccounts(ctx context.Context, ids []int64, platform := c.Query("platform") accountType := c.Query("type") status := c.Query("status") + model := strings.TrimSpace(c.Query("model")) + quotaStrategy := strings.TrimSpace(c.Query("quota_strategy")) + proxyFilter := strings.TrimSpace(c.Query("proxy_filter")) privacyMode := strings.TrimSpace(c.Query("privacy_mode")) search := strings.TrimSpace(c.Query("search")) sortBy := c.DefaultQuery("sort_by", "name") @@ -431,7 +434,7 @@ func (h *AccountHandler) resolveExportAccounts(ctx context.Context, ids []int64, } } - return h.listAccountsFiltered(ctx, platform, accountType, status, search, groupID, privacyMode, sortBy, sortOrder) + return h.listAccountsFiltered(ctx, platform, accountType, status, search, groupID, model, quotaStrategy, proxyFilter, privacyMode, sortBy, sortOrder) } func (h *AccountHandler) resolveExportProxies(ctx context.Context, accounts []service.Account) ([]service.Proxy, error) { diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index ffab74d6a7a..d9d9b9bf34f 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -11,6 +11,7 @@ import ( "log" "log/slog" "net/http" + "net/http/httptest" "strconv" "strings" "sync" @@ -58,6 +59,7 @@ type AccountHandler struct { sessionLimitCache service.SessionLimitCache rpmCache service.RPMCache tokenCacheInvalidator service.TokenCacheInvalidator + accountTestQueue *accountTestQueue } // NewAccountHandler creates a new admin account handler @@ -90,6 +92,7 @@ func NewAccountHandler( sessionLimitCache: sessionLimitCache, rpmCache: rpmCache, tokenCacheInvalidator: tokenCacheInvalidator, + accountTestQueue: newAccountTestQueue(3 * time.Second), } } @@ -230,6 +233,9 @@ func (h *AccountHandler) List(c *gin.Context) { accountType := c.Query("type") status := c.Query("status") search := c.Query("search") + model := strings.TrimSpace(c.Query("model")) + quotaStrategy := strings.TrimSpace(c.Query("quota_strategy")) + proxyFilter := strings.TrimSpace(c.Query("proxy_filter")) privacyMode := strings.TrimSpace(c.Query("privacy_mode")) sortBy := c.DefaultQuery("sort_by", "name") sortOrder := c.DefaultQuery("sort_order", "asc") @@ -258,7 +264,7 @@ func (h *AccountHandler) List(c *gin.Context) { } } - accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID, privacyMode, sortBy, sortOrder) + accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID, model, quotaStrategy, proxyFilter, privacyMode, sortBy, sortOrder) if err != nil { response.ErrorFrom(c, err) return @@ -730,17 +736,9 @@ func (h *AccountHandler) Test(c *gin.Context) { // Allow empty body, model_id is optional _ = c.ShouldBindJSON(&req) - // Use AccountTestService to test the account with SSE streaming - if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt, req.Mode); err != nil { - // Error already sent via SSE, just log + if err := h.runQueuedInteractiveAccountTest(c, accountID, req); err != nil { return } - - if h.rateLimitService != nil { - if _, err := h.rateLimitService.RecoverAccountAfterSuccessfulTest(c.Request.Context(), accountID); err != nil { - _ = c.Error(err) - } - } } // RecoverState handles unified recovery of recoverable account runtime state. @@ -1200,6 +1198,78 @@ func (h *AccountHandler) BatchRefresh(c *gin.Context) { }) } +// BatchTest handles batch testing account connectivity. +// POST /api/v1/admin/accounts/batch-test +func (h *AccountHandler) BatchTest(c *gin.Context) { + if h.accountTestService == nil { + response.Error(c, http.StatusServiceUnavailable, "Account test service unavailable") + return + } + + var req struct { + AccountIDs []int64 `json:"account_ids"` + } + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + if len(req.AccountIDs) == 0 { + response.BadRequest(c, "account_ids is required") + return + } + + ctx := c.Request.Context() + accounts, err := h.adminService.GetAccountsByIDs(ctx, req.AccountIDs) + if err != nil { + response.ErrorFrom(c, err) + return + } + + foundIDs := make(map[int64]bool, len(accounts)) + for _, acc := range accounts { + if acc != nil { + foundIDs[acc.ID] = true + } + } + + successCount := 0 + failedCount := 0 + errors := make([]gin.H, 0) + + for _, id := range req.AccountIDs { + if foundIDs[id] { + continue + } + failedCount++ + errors = append(errors, gin.H{ + "account_id": id, + "error": "account not found", + }) + } + + for _, account := range accounts { + if account == nil { + continue + } + if err := h.runQueuedBackgroundAccountTest(ctx, account.ID); err != nil { + failedCount++ + errors = append(errors, gin.H{ + "account_id": account.ID, + "error": err.Error(), + }) + continue + } + successCount++ + } + + response.Success(c, gin.H{ + "total": len(req.AccountIDs), + "success": successCount, + "failed": failedCount, + "errors": errors, + }) +} + // BatchCreate handles batch creating accounts // POST /api/v1/admin/accounts/batch func (h *AccountHandler) BatchCreate(c *gin.Context) { @@ -1691,6 +1761,12 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) { response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) } +func (h *AccountHandler) GetFilterModels(c *gin.Context) { + platform := strings.TrimSpace(c.Query("platform")) + groups := service.ListAccountModelFilterGroups() + response.Success(c, service.FilterAccountModelGroupsByPlatform(groups, platform)) +} + // ResetQuota handles resetting account quota usage // POST /api/v1/admin/accounts/:id/reset-quota func (h *AccountHandler) ResetQuota(c *gin.Context) { @@ -1705,15 +1781,75 @@ func (h *AccountHandler) ResetQuota(c *gin.Context) { return } + if h.rateLimitService != nil { + if _, err := h.rateLimitService.RecoverAccountState(c.Request.Context(), accountID, service.AccountRecoveryOptions{ + InvalidateToken: true, + }); err != nil { + response.ErrorFrom(c, err) + return + } + } + account, err := h.adminService.GetAccount(c.Request.Context(), accountID) if err != nil { response.ErrorFrom(c, err) return } + if account.Platform == service.PlatformOpenAI && h.accountTestService != nil { + if err := h.runQueuedBackgroundAccountTest(c.Request.Context(), accountID); err != nil { + log.Printf("[WARN] auto test after quota reset failed for account %d: %v", accountID, err) + } + account, err = h.adminService.GetAccount(c.Request.Context(), accountID) + if err != nil { + response.ErrorFrom(c, err) + return + } + } + response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) } +func (h *AccountHandler) runQueuedInteractiveAccountTest(c *gin.Context, accountID int64, req TestAccountRequest) error { + if h.accountTestService == nil { + response.Error(c, http.StatusServiceUnavailable, "Account test service unavailable") + return errors.New("account test service unavailable") + } + + return h.accountTestQueue.Run(c.Request.Context(), func() error { + if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt, req.Mode); err != nil { + return err + } + if h.rateLimitService != nil { + if _, err := h.rateLimitService.RecoverAccountAfterSuccessfulTest(c.Request.Context(), accountID); err != nil { + _ = c.Error(err) + } + } + return nil + }) +} + +func (h *AccountHandler) runQueuedBackgroundAccountTest(ctx context.Context, accountID int64) error { + if h.accountTestService == nil { + return errors.New("account test service unavailable") + } + + return h.accountTestQueue.Run(ctx, func() error { + recorder := httptest.NewRecorder() + testCtx, _ := gin.CreateTestContext(recorder) + testCtx.Request = httptest.NewRequest(http.MethodPost, fmt.Sprintf("/api/v1/admin/accounts/%d/test", accountID), nil).WithContext(ctx) + if err := h.accountTestService.TestAccountConnection(testCtx, accountID, "", "", ""); err != nil { + return err + } + if h.rateLimitService != nil { + if _, err := h.rateLimitService.RecoverAccountAfterSuccessfulTest(ctx, accountID); err != nil { + return err + } + } + return nil + }) +} + // GetTempUnschedulable handles getting temporary unschedulable status // GET /api/v1/admin/accounts/:id/temp-unschedulable func (h *AccountHandler) GetTempUnschedulable(c *gin.Context) { @@ -2107,7 +2243,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) { accounts := make([]*service.Account, 0) if len(req.AccountIDs) == 0 { - allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0, "", "name", "asc") + allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0, "", "", "", "", "name", "asc") if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/admin/account_handler_reset_quota_test.go b/backend/internal/handler/admin/account_handler_reset_quota_test.go new file mode 100644 index 00000000000..ea228fedf67 --- /dev/null +++ b/backend/internal/handler/admin/account_handler_reset_quota_test.go @@ -0,0 +1,227 @@ +package admin + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type accountHandlerRateLimitRepoStub struct { + service.AccountRepository + accounts map[int64]*service.Account + getErr map[int64]error + clearErrs []int64 +} + +func (s *accountHandlerRateLimitRepoStub) GetByID(_ context.Context, id int64) (*service.Account, error) { + if err, ok := s.getErr[id]; ok { + return nil, err + } + if account, ok := s.accounts[id]; ok { + return account, nil + } + return nil, service.ErrAccountNotFound +} + +func (s *accountHandlerRateLimitRepoStub) ClearError(_ context.Context, id int64) error { + s.clearErrs = append(s.clearErrs, id) + if account, ok := s.accounts[id]; ok { + account.Status = service.StatusActive + account.ErrorMessage = "" + } + return nil +} + +func newAccountHandlerForResetQuotaTest(adminSvc service.AdminService, rateLimitSvc *service.RateLimitService) *AccountHandler { + return NewAccountHandler(adminSvc, nil, nil, nil, nil, rateLimitSvc, nil, nil, nil, nil, nil, nil, nil) +} + +type accountHandlerTestHTTPUpstream struct { + requestCount int +} + +func (s *accountHandlerTestHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) { + return nil, errors.New("unexpected Do call") +} + +func (s *accountHandlerTestHTTPUpstream) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, _ *tlsfingerprint.Profile) (*http.Response, error) { + s.requestCount++ + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + }, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.output_text.delta\",\"delta\":\"hi\"}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{}}\n\n" + + "data: [DONE]\n\n", + )), + }, nil +} + +func TestAccountHandler_ResetQuota_RecoversAccountStateAndReturnsUpdatedAccount(t *testing.T) { + gin.SetMode(gin.TestMode) + + adminSvc := newStubAdminService() + adminSvc.accounts = []service.Account{{ID: 42, Name: "before", Status: service.StatusError}} + adminSvc.getAccountByID = map[int64]*service.Account{ + 42: {ID: 42, Name: "after", Status: service.StatusActive}, + } + + repo := &accountHandlerRateLimitRepoStub{ + accounts: map[int64]*service.Account{ + 42: { + ID: 42, + Name: "before", + Status: service.StatusError, + ErrorMessage: "401", + }, + }, + } + rateLimitSvc := service.NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + handler := newAccountHandlerForResetQuotaTest(adminSvc, rateLimitSvc) + + router := gin.New() + router.POST("/api/v1/admin/accounts/:id/reset-quota", handler.ResetQuota) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/42/reset-quota", nil) + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, []int64{42}, adminSvc.resetAccountQuotaIDs) + require.Equal(t, []int64{42}, repo.clearErrs) + + var payload struct { + Code int `json:"code"` + Data struct { + ID int64 `json:"id"` + Name string `json:"name"` + Status string `json:"status"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload)) + require.Equal(t, 0, payload.Code) + require.Equal(t, int64(42), payload.Data.ID) + require.Equal(t, "after", payload.Data.Name) + require.Equal(t, service.StatusActive, payload.Data.Status) +} + +func TestAccountHandler_ResetQuota_OpenAIAccountTriggersAutoTest(t *testing.T) { + gin.SetMode(gin.TestMode) + + adminSvc := newStubAdminService() + adminSvc.getAccountByID = map[int64]*service.Account{ + 42: {ID: 42, Name: "openai", Platform: service.PlatformOpenAI, Type: service.AccountTypeOAuth, Status: service.StatusActive}, + } + + repo := &accountHandlerRateLimitRepoStub{ + accounts: map[int64]*service.Account{ + 42: { + ID: 42, + Name: "openai", + Platform: service.PlatformOpenAI, + Type: service.AccountTypeOAuth, + Status: service.StatusActive, + Concurrency: 1, + Schedulable: true, + Credentials: map[string]any{"access_token": "token"}, + Extra: map[string]any{}, + ErrorMessage: "", + }, + }, + } + rateLimitSvc := service.NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + upstream := &accountHandlerTestHTTPUpstream{} + accountTestSvc := service.NewAccountTestService(repo, nil, nil, nil, upstream, &config.Config{}, &service.TLSFingerprintProfileService{}) + + handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, rateLimitSvc, nil, accountTestSvc, nil, nil, nil, nil, nil) + + router := gin.New() + router.POST("/api/v1/admin/accounts/:id/reset-quota", handler.ResetQuota) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/42/reset-quota", nil) + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, 1, upstream.requestCount) +} + +func TestAccountHandler_BatchTest_ReturnsSummary(t *testing.T) { + gin.SetMode(gin.TestMode) + + adminSvc := newStubAdminService() + adminSvc.getAccountByID = map[int64]*service.Account{ + 42: {ID: 42, Name: "openai-a", Platform: service.PlatformOpenAI, Type: service.AccountTypeOAuth, Status: service.StatusActive}, + 43: {ID: 43, Name: "openai-b", Platform: service.PlatformOpenAI, Type: service.AccountTypeOAuth, Status: service.StatusActive}, + } + + repo := &accountHandlerRateLimitRepoStub{ + accounts: map[int64]*service.Account{ + 42: { + ID: 42, + Name: "openai-a", + Platform: service.PlatformOpenAI, + Type: service.AccountTypeOAuth, + Status: service.StatusActive, + Concurrency: 1, + Schedulable: true, + Credentials: map[string]any{"access_token": "token-a"}, + Extra: map[string]any{}, + }, + 43: { + ID: 43, + Name: "openai-b", + Platform: service.PlatformOpenAI, + Type: service.AccountTypeOAuth, + Status: service.StatusActive, + Concurrency: 1, + Schedulable: true, + Credentials: map[string]any{"access_token": "token-b"}, + Extra: map[string]any{}, + }, + }, + } + rateLimitSvc := service.NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + upstream := &accountHandlerTestHTTPUpstream{} + accountTestSvc := service.NewAccountTestService(repo, nil, nil, nil, upstream, &config.Config{}, &service.TLSFingerprintProfileService{}) + + handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, rateLimitSvc, nil, accountTestSvc, nil, nil, nil, nil, nil) + + router := gin.New() + router.POST("/api/v1/admin/accounts/batch-test", handler.BatchTest) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/batch-test", strings.NewReader(`{"account_ids":[42,43]}`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, 2, upstream.requestCount) + + var payload struct { + Code int `json:"code"` + Data struct { + Total int `json:"total"` + Success int `json:"success"` + Failed int `json:"failed"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload)) + require.Equal(t, 0, payload.Code) + require.Equal(t, 2, payload.Data.Total) + require.Equal(t, 2, payload.Data.Success) + require.Equal(t, 0, payload.Data.Failed) +} diff --git a/backend/internal/handler/admin/account_test_queue.go b/backend/internal/handler/admin/account_test_queue.go new file mode 100644 index 00000000000..821ba1c3b29 --- /dev/null +++ b/backend/internal/handler/admin/account_test_queue.go @@ -0,0 +1,42 @@ +package admin + +import ( + "context" + "sync" + "time" +) + +type accountTestQueue struct { + mu sync.Mutex + nextAvailableAt time.Time + cooldown time.Duration +} + +func newAccountTestQueue(cooldown time.Duration) *accountTestQueue { + return &accountTestQueue{cooldown: cooldown} +} + +func (q *accountTestQueue) Run(ctx context.Context, fn func() error) error { + if q == nil { + return fn() + } + + q.mu.Lock() + defer q.mu.Unlock() + + if wait := time.Until(q.nextAvailableAt); wait > 0 { + timer := time.NewTimer(wait) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + } + } + + err := fn() + if q.cooldown > 0 { + q.nextAvailableAt = time.Now().Add(q.cooldown) + } + return err +} diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index 2fef94f1561..a60334a5f93 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -24,6 +24,8 @@ type stubAdminService struct { updatedProxyIDs []int64 updatedProxies []*service.UpdateProxyInput testedProxyIDs []int64 + getAccountByID map[int64]*service.Account + resetAccountQuotaIDs []int64 createAccountErr error updateAccountErr error bulkUpdateAccountErr error @@ -34,15 +36,18 @@ type stubAdminService struct { groupIDs []int64 } lastListAccounts struct { - platform string - accountType string - status string - search string - groupID int64 - privacyMode string - sortBy string - sortOrder string - calls int + platform string + accountType string + status string + search string + groupID int64 + model string + quotaStrategy string + proxyFilter string + privacyMode string + sortBy string + sortOrder string + calls int } lastListUsers struct { page int @@ -299,12 +304,15 @@ func (s *stubAdminService) BatchSetGroupRPMOverrides(_ context.Context, _ int64, return nil } -func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]service.Account, int64, error) { +func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, model, quotaStrategy, proxyFilter, privacyMode string, sortBy, sortOrder string) ([]service.Account, int64, error) { s.lastListAccounts.platform = platform s.lastListAccounts.accountType = accountType s.lastListAccounts.status = status s.lastListAccounts.search = search s.lastListAccounts.groupID = groupID + s.lastListAccounts.model = model + s.lastListAccounts.quotaStrategy = quotaStrategy + s.lastListAccounts.proxyFilter = proxyFilter s.lastListAccounts.privacyMode = privacyMode s.lastListAccounts.sortBy = sortBy s.lastListAccounts.sortOrder = sortOrder @@ -313,6 +321,9 @@ func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, } func (s *stubAdminService) GetAccount(ctx context.Context, id int64) (*service.Account, error) { + if account, ok := s.getAccountByID[id]; ok && account != nil { + return account, nil + } account := service.Account{ID: id, Name: "account", Status: service.StatusActive} return &account, nil } @@ -320,6 +331,10 @@ func (s *stubAdminService) GetAccount(ctx context.Context, id int64) (*service.A func (s *stubAdminService) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) { out := make([]*service.Account, 0, len(ids)) for _, id := range ids { + if account, ok := s.getAccountByID[id]; ok && account != nil { + out = append(out, account) + continue + } account := service.Account{ID: id, Name: "account", Status: service.StatusActive} out = append(out, &account) } @@ -586,6 +601,7 @@ func (s *stubAdminService) AdminResetAPIKeyRateLimitUsage(ctx context.Context, k } func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) error { + s.resetAccountQuotaIDs = append(s.resetAccountQuotaIDs, id) return nil } diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 78f739ac205..01fbf686e49 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -457,10 +457,10 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error { } func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { - return r.ListWithFilters(ctx, params, "", "", "", "", 0, "") + return r.ListWithFilters(ctx, params, "", "", "", "", 0, "", "", "", "") } -func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) { +func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, model, quotaStrategy, proxyFilter, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) { q := r.client.Account.Query() if platform != "" { @@ -487,6 +487,27 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati )) }), ) + case service.AccountStatusFilterActiveExcludingQuotaStopped: + q = q.Where( + dbaccount.StatusEQ(service.StatusActive), + dbaccount.SchedulableEQ(true), + dbaccount.Or( + dbaccount.RateLimitResetAtIsNil(), + dbaccount.RateLimitResetAtLTE(time.Now()), + ), + dbpredicate.Account(func(s *entsql.Selector) { + col := s.C("temp_unschedulable_until") + s.Where(entsql.Or( + entsql.IsNull(col), + entsql.LTE(col, entsql.Expr("NOW()")), + )) + }), + ) + case service.AccountStatusFilterOpenAI5HUsedZero, service.AccountStatusFilterOpenAI7DUsedZero: + q = q.Where( + dbaccount.PlatformEQ(service.PlatformOpenAI), + dbaccount.TypeEQ(service.AccountTypeOAuth), + ) case "rate_limited": q = q.Where( dbaccount.StatusEQ(service.StatusActive), @@ -533,6 +554,23 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati if search != "" { q = q.Where(dbaccount.NameContainsFold(search)) } + switch normalizedProxyFilter := strings.TrimSpace(proxyFilter); normalizedProxyFilter { + case "": + case "configured": + q = q.Where(dbaccount.ProxyIDNotNil()) + case "unconfigured": + q = q.Where(dbaccount.ProxyIDIsNil()) + default: + if !strings.HasPrefix(normalizedProxyFilter, "proxy:") { + return []service.Account{}, paginationResultFromTotal(0, params), nil + } + proxyIDText := strings.TrimSpace(strings.TrimPrefix(normalizedProxyFilter, "proxy:")) + proxyID, err := strconv.ParseInt(proxyIDText, 10, 64) + if err != nil || proxyID <= 0 { + return []service.Account{}, paginationResultFromTotal(0, params), nil + } + q = q.Where(dbaccount.ProxyIDEQ(proxyID)) + } if groupID == service.AccountListGroupUngrouped { q = q.Where(dbaccount.Not(dbaccount.HasAccountGroups())) } else if groupID > 0 { @@ -553,6 +591,46 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati })) } + normalizedStatus := strings.TrimSpace(status) + if normalizedStatus == service.AccountStatusFilterActiveExcludingQuotaStopped || + normalizedStatus == service.AccountStatusFilterOpenAI5HUsedZero || + normalizedStatus == service.AccountStatusFilterOpenAI7DUsedZero || + strings.TrimSpace(model) != "" || + strings.TrimSpace(quotaStrategy) != "" { + accountsQuery := q + for _, order := range accountListOrder(params) { + accountsQuery = accountsQuery.Order(order) + } + accounts, err := accountsQuery.All(ctx) + if err != nil { + return nil, nil, err + } + outAccounts, err := r.accountsToService(ctx, accounts) + if err != nil { + return nil, nil, err + } + filtered := make([]service.Account, 0, len(outAccounts)) + now := time.Now() + for i := range outAccounts { + matchesModel := strings.TrimSpace(model) == "" || service.IsAccountSupportedForModelFilter(&outAccounts[i], model) + if service.MatchesAccountListStatusFilter(&outAccounts[i], normalizedStatus, now) && + matchesModel && + service.MatchesOpenAIQuotaStrategyFilter(&outAccounts[i], quotaStrategy) { + filtered = append(filtered, outAccounts[i]) + } + } + total := int64(len(filtered)) + start := params.Offset() + if start >= len(filtered) { + return []service.Account{}, paginationResultFromTotal(total, params), nil + } + end := start + params.Limit() + if end > len(filtered) { + end = len(filtered) + } + return filtered[start:end], paginationResultFromTotal(total, params), nil + } + total, err := q.Count(ctx) if err != nil { return nil, nil, err diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go index d1cea9eb3b0..c85a0b171c5 100644 --- a/backend/internal/repository/account_repo_integration_test.go +++ b/backend/internal/repository/account_repo_integration_test.go @@ -354,6 +354,41 @@ func (s *AccountRepoSuite) TestListWithFilters() { s.Require().Equal("active-temp-unsched", accounts[0].Name) }, }, + { + name: "filter_by_status_active_excluding_quota_stopped_with_empty_model_keeps_matching_accounts", + setup: func(client *dbent.Client) { + mustCreateAccount(s.T(), client, &service.Account{ + Name: "quota-ok", + Platform: service.PlatformOpenAI, + Type: service.AccountTypeOAuth, + Status: service.StatusActive, + Schedulable: true, + Extra: map[string]any{ + "openai_quota_strategy": "prefer_7d", + "openai_quota_stop_threshold_percent": 10, + "codex_7d_used_percent": 23, + }, + }) + mustCreateAccount(s.T(), client, &service.Account{ + Name: "quota-stopped", + Platform: service.PlatformOpenAI, + Type: service.AccountTypeOAuth, + Status: service.StatusActive, + Schedulable: true, + Extra: map[string]any{ + "openai_quota_strategy": "prefer_7d", + "openai_quota_stop_threshold_percent": 10, + "codex_7d_used_percent": 97, + }, + }) + }, + status: service.AccountStatusFilterActiveExcludingQuotaStopped, + model: "", + wantCount: 1, + validate: func(accounts []service.Account) { + s.Require().Equal("quota-ok", accounts[0].Name) + }, + }, { name: "filter_by_search", setup: func(client *dbent.Client) { @@ -419,7 +454,7 @@ func (s *AccountRepoSuite) TestListWithFilters() { tt.setup(client) - accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, tt.groupID, tt.privacyMode) + accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, tt.groupID, "", "", "", tt.privacyMode) s.Require().NoError(err) s.Require().Len(accounts, tt.wantCount) if tt.validate != nil { @@ -486,7 +521,7 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() { s.Require().Len(got.Groups, 1, "expected Groups to be populated") s.Require().Equal(group.ID, got.Groups[0].ID) - accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc", 0, "") + accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc", 0, "", "", "", "") s.Require().NoError(err, "ListWithFilters") s.Require().Equal(int64(1), page.Total) s.Require().Len(accounts, 1) diff --git a/backend/internal/repository/billing_cache.go b/backend/internal/repository/billing_cache.go index 6922b4c8bbd..044d9467105 100644 --- a/backend/internal/repository/billing_cache.go +++ b/backend/internal/repository/billing_cache.go @@ -33,7 +33,7 @@ func jitteredTTL() time.Duration { if billingCacheJitter <= 0 { return billingCacheTTL } - jitter := time.Duration(rand.IntN(int(billingCacheJitter))) + jitter := time.Duration(rand.Int64N(int64(billingCacheJitter))) return billingCacheTTL - jitter } diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 27358865666..8cc44693393 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -1566,7 +1566,7 @@ func (s *stubAccountRepo) List(ctx context.Context, params pagination.Pagination return nil, nil, errors.New("not implemented") } -func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) { +func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, model, quotaStrategy, proxyFilter, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 6e1059bc829..5f63b53387f 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -279,6 +279,8 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { accounts := admin.Group("/accounts") { accounts.GET("", h.Admin.Account.List) + // Keep static GET routes above "/:id" to avoid path parameters swallowing them. + accounts.GET("/filter-models", h.Admin.Account.GetFilterModels) accounts.GET("/:id", h.Admin.Account.GetByID) accounts.POST("", h.Admin.Account.Create) accounts.POST("/check-mixed-channel", h.Admin.Account.CheckMixedChannel) @@ -311,6 +313,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate) accounts.POST("/batch-clear-error", h.Admin.Account.BatchClearError) accounts.POST("/batch-refresh", h.Admin.Account.BatchRefresh) + accounts.POST("/batch-test", h.Admin.Account.BatchTest) // Antigravity 默认模型映射 accounts.GET("/antigravity/default-model-mapping", h.Admin.Account.GetAntigravityDefaultModelMapping) diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index cd06ffa3c49..8fb83978959 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -1046,6 +1046,72 @@ func (a *Account) GetOpenAISessionID() string { return strings.TrimSpace(a.GetExtraString("openai_session_id")) } +func (a *Account) GetOpenAIQuotaStrategy() string { + if !a.IsOpenAIOAuth() { + return "" + } + switch strings.TrimSpace(a.GetExtraString("openai_quota_strategy")) { + case "prefer_5h", "prefer_7d": + return strings.TrimSpace(a.GetExtraString("openai_quota_strategy")) + default: + return "" + } +} + +func (a *Account) GetOpenAIQuotaStopThresholdPercent() float64 { + if !a.IsOpenAIOAuth() { + return 0 + } + threshold := a.getExtraFloat64("openai_quota_stop_threshold_percent") + if threshold <= 0 { + return 10 + } + if threshold > 100 { + return 100 + } + return threshold +} + +func (a *Account) GetOpenAIQuotaRemainingPercentByStrategy() (float64, bool) { + if !a.IsOpenAIOAuth() || a.Extra == nil { + return 0, false + } + + var window string + switch a.GetOpenAIQuotaStrategy() { + case "prefer_5h": + window = "5h" + case "prefer_7d": + window = "7d" + default: + return 0, false + } + + progress := buildCodexUsageProgressFromExtra(a.Extra, window, time.Now()) + if progress == nil { + return 0, false + } + used := progress.Utilization + if used < 0 { + used = 0 + } + if used > 100 { + used = 100 + } + return 100 - used, true +} + +func (a *Account) IsOpenAIQuotaStrategySchedulable() bool { + if a.GetOpenAIQuotaStrategy() == "" { + return true + } + remaining, ok := a.GetOpenAIQuotaRemainingPercentByStrategy() + if !ok { + return true + } + return remaining >= a.GetOpenAIQuotaStopThresholdPercent() +} + func (a *Account) SupportsOpenAIImageCapability(capability OpenAIImagesCapability) bool { if !a.IsOpenAI() { return false diff --git a/backend/internal/service/account_model_filter.go b/backend/internal/service/account_model_filter.go new file mode 100644 index 00000000000..275646ba109 --- /dev/null +++ b/backend/internal/service/account_model_filter.go @@ -0,0 +1,327 @@ +package service + +import ( + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" +) + +type AccountModelFilterEntry struct { + Value string `json:"value"` + Label string `json:"label"` +} + +type AccountModelFilterGroup struct { + Platform string `json:"platform"` + Label string `json:"label"` + Models []AccountModelFilterEntry `json:"models"` +} + +const ( + AccountModelFilterLimited = "__limited__" + AccountModelFilterUnlimited = "__unlimited__" + AccountStatusFilterActiveExcludingQuotaStopped = "active_excluding_quota_stopped" + AccountStatusFilterOpenAI5HUsedZero = "openai_5h_used_zero" + AccountStatusFilterOpenAI7DUsedZero = "openai_7d_used_zero" +) + +func ListAccountModelFilterGroups() []AccountModelFilterGroup { + return []AccountModelFilterGroup{ + { + Platform: PlatformOpenAI, + Label: "OpenAI", + Models: buildOpenAIModelFilterEntries(), + }, + { + Platform: PlatformAnthropic, + Label: "Anthropic", + Models: buildClaudeModelFilterEntries(), + }, + { + Platform: PlatformGemini, + Label: "Gemini", + Models: buildGeminiModelFilterEntries(), + }, + { + Platform: PlatformAntigravity, + Label: "Antigravity", + Models: buildAntigravityModelFilterEntries(), + }, + } +} + +func FilterAccountModelGroupsByPlatform(groups []AccountModelFilterGroup, platform string) []AccountModelFilterGroup { + normalizedPlatform := strings.TrimSpace(platform) + if normalizedPlatform == "" { + return groups + } + + filtered := make([]AccountModelFilterGroup, 0, 1) + for _, group := range groups { + if group.Platform == normalizedPlatform { + filtered = append(filtered, group) + break + } + } + return filtered +} + +func IsAccountSupportedForModelFilter(account *Account, requestedModel string) bool { + if account == nil { + return false + } + + trimmed := strings.TrimSpace(requestedModel) + if trimmed == "" { + return false + } + + switch trimmed { + case AccountModelFilterLimited: + return hasExplicitModelRestriction(account) + case AccountModelFilterUnlimited: + return !hasExplicitModelRestriction(account) + } + + mapping := account.GetModelMapping() + if len(mapping) == 0 { + return true + } + + if account.IsAnthropicOAuthOrSetupToken() { + for _, alias := range buildAnthropicModelAliases(trimmed) { + if mappingSupportsRequestedModel(mapping, alias) { + return true + } + } + if mappingSupportsAnthropicModelAlias(mapping, trimmed) { + return true + } + } + + return account.IsModelSupported(trimmed) +} + +func hasExplicitModelRestriction(account *Account) bool { + if account == nil || account.Credentials == nil { + return false + } + rawMapping, ok := account.Credentials["model_mapping"] + if !ok || rawMapping == nil { + return false + } + switch mapping := rawMapping.(type) { + case map[string]any: + return len(mapping) > 0 + case map[string]string: + return len(mapping) > 0 + default: + return false + } +} + +func MatchesOpenAIQuotaStrategyFilter(account *Account, requestedStrategy string) bool { + if account == nil { + return false + } + + switch strings.TrimSpace(requestedStrategy) { + case "": + return true + case "prefer_5h": + return account.GetOpenAIQuotaStrategy() == "prefer_5h" + case "prefer_7d": + return account.GetOpenAIQuotaStrategy() == "prefer_7d" + case "enabled": + strategy := account.GetOpenAIQuotaStrategy() + return strategy == "prefer_5h" || strategy == "prefer_7d" + case "disabled": + return account.GetOpenAIQuotaStrategy() == "" + default: + return false + } +} + +func MatchesAccountListStatusFilter(account *Account, requestedStatus string, now time.Time) bool { + if account == nil { + return false + } + + switch strings.TrimSpace(requestedStatus) { + case "": + return true + case StatusActive: + return matchesActiveAccountListStatusFilter(account, now, false) + case AccountStatusFilterActiveExcludingQuotaStopped: + return matchesActiveAccountListStatusFilter(account, now, true) + case "rate_limited": + return account.Status == StatusActive && + isAccountRateLimitedAt(account, now) && + !isAccountTempUnschedulableAt(account, now) + case "temp_unschedulable": + return account.Status == StatusActive && isAccountTempUnschedulableAt(account, now) + case "unschedulable": + return account.Status == StatusActive && + !account.Schedulable && + !isAccountRateLimitedAt(account, now) && + !isAccountTempUnschedulableAt(account, now) + case AccountStatusFilterOpenAI5HUsedZero: + return isOpenAIUsagePercentExactlyZero(account, "codex_5h_used_percent") + case AccountStatusFilterOpenAI7DUsedZero: + return isOpenAIUsagePercentExactlyZero(account, "codex_7d_used_percent") + default: + return account.Status == requestedStatus + } +} + +func matchesActiveAccountListStatusFilter(account *Account, now time.Time, excludeQuotaStopped bool) bool { + if account == nil { + return false + } + if account.Status != StatusActive || !account.Schedulable { + return false + } + if isAccountRateLimitedAt(account, now) || isAccountTempUnschedulableAt(account, now) { + return false + } + if excludeQuotaStopped && !account.IsOpenAIQuotaStrategySchedulable() { + return false + } + return true +} + +func isAccountRateLimitedAt(account *Account, now time.Time) bool { + return account != nil && account.RateLimitResetAt != nil && account.RateLimitResetAt.After(now) +} + +func isAccountTempUnschedulableAt(account *Account, now time.Time) bool { + return account != nil && account.TempUnschedulableUntil != nil && account.TempUnschedulableUntil.After(now) +} + +func isOpenAIUsagePercentExactlyZero(account *Account, key string) bool { + if account == nil || account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth || account.Extra == nil { + return false + } + window := "" + switch key { + case "codex_5h_used_percent": + window = "5h" + case "codex_7d_used_percent": + window = "7d" + default: + return false + } + progress := buildCodexUsageProgressFromExtra(account.Extra, window, time.Now()) + return progress != nil && progress.Utilization == 0 +} + +func buildAnthropicModelAliases(requestedModel string) []string { + trimmed := strings.TrimSpace(requestedModel) + if trimmed == "" { + return nil + } + + aliases := make([]string, 0, 8) + seen := make(map[string]struct{}, 8) + add := func(value string) { + value = strings.TrimSpace(value) + if value == "" { + return + } + if _, ok := seen[value]; ok { + return + } + seen[value] = struct{}{} + aliases = append(aliases, value) + } + + add(trimmed) + add(claude.NormalizeModelID(trimmed)) + add(claude.DenormalizeModelID(trimmed)) + + if short := stripAnthropicDateSuffix(trimmed); short != trimmed { + add(short) + add(claude.NormalizeModelID(short)) + } + + for _, model := range claude.DefaultModels { + modelID := strings.TrimSpace(model.ID) + shortID := stripAnthropicDateSuffix(modelID) + if trimmed == modelID || trimmed == shortID { + add(modelID) + add(shortID) + } + } + + return aliases +} + +func mappingSupportsAnthropicModelAlias(mapping map[string]string, requestedModel string) bool { + normalizedRequested := stripAnthropicDateSuffix(requestedModel) + if normalizedRequested == "" { + return false + } + for key, value := range mapping { + if stripAnthropicDateSuffix(key) == normalizedRequested { + return true + } + if stripAnthropicDateSuffix(value) == normalizedRequested { + return true + } + } + return false +} + +func stripAnthropicDateSuffix(model string) string { + parts := strings.Split(strings.TrimSpace(model), "-") + if len(parts) < 2 { + return strings.TrimSpace(model) + } + last := parts[len(parts)-1] + if len(last) != 8 { + return strings.TrimSpace(model) + } + for _, ch := range last { + if ch < '0' || ch > '9' { + return strings.TrimSpace(model) + } + } + return strings.Join(parts[:len(parts)-1], "-") +} + +func buildOpenAIModelFilterEntries() []AccountModelFilterEntry { + entries := make([]AccountModelFilterEntry, 0, len(openai.DefaultModels)) + for _, model := range openai.DefaultModels { + entries = append(entries, AccountModelFilterEntry{Value: model.ID, Label: model.DisplayName}) + } + return entries +} + +func buildClaudeModelFilterEntries() []AccountModelFilterEntry { + entries := make([]AccountModelFilterEntry, 0, len(claude.DefaultModels)) + for _, model := range claude.DefaultModels { + entries = append(entries, AccountModelFilterEntry{Value: model.ID, Label: model.DisplayName}) + } + return entries +} + +func buildGeminiModelFilterEntries() []AccountModelFilterEntry { + entries := make([]AccountModelFilterEntry, 0, len(geminicli.DefaultModels)) + for _, model := range geminicli.DefaultModels { + entries = append(entries, AccountModelFilterEntry{Value: model.ID, Label: model.DisplayName}) + } + return entries +} + +func buildAntigravityModelFilterEntries() []AccountModelFilterEntry { + models := antigravity.DefaultModels() + entries := make([]AccountModelFilterEntry, 0, len(models)) + for _, model := range models { + entries = append(entries, AccountModelFilterEntry{Value: model.ID, Label: model.DisplayName}) + } + return entries +} diff --git a/backend/internal/service/account_model_filter_test.go b/backend/internal/service/account_model_filter_test.go new file mode 100644 index 00000000000..fba9ec73493 --- /dev/null +++ b/backend/internal/service/account_model_filter_test.go @@ -0,0 +1,194 @@ +package service + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestIsAccountSupportedForModelFilter(t *testing.T) { + t.Run("Anthropic OAuth 支持短 ID 命中带日期后缀模型", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-3-7-sonnet-20250219": "claude-3-7-sonnet-20250219", + }, + }, + } + + require.True(t, IsAccountSupportedForModelFilter(account, "claude-3-7-sonnet")) + }) +} + +func TestMatchesOpenAIQuotaStrategyFilter(t *testing.T) { + account5h := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_quota_strategy": "prefer_5h", + }, + } + account7d := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_quota_strategy": "prefer_7d", + }, + } + accountDisabled := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{}, + } + + tests := []struct { + name string + account *Account + filter string + expected bool + }{ + {name: "no_restriction", account: account5h, filter: "", expected: true}, + {name: "prefer_5h", account: account5h, filter: "prefer_5h", expected: true}, + {name: "prefer_5h_mismatch", account: account7d, filter: "prefer_5h", expected: false}, + {name: "prefer_7d", account: account7d, filter: "prefer_7d", expected: true}, + {name: "enabled_matches_5h", account: account5h, filter: "enabled", expected: true}, + {name: "enabled_matches_7d", account: account7d, filter: "enabled", expected: true}, + {name: "disabled_matches_empty", account: accountDisabled, filter: "disabled", expected: true}, + {name: "disabled_rejects_enabled", account: account5h, filter: "disabled", expected: false}, + {name: "unknown_filter", account: account5h, filter: "unknown", expected: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := MatchesOpenAIQuotaStrategyFilter(tt.account, tt.filter); got != tt.expected { + t.Fatalf("MatchesOpenAIQuotaStrategyFilter() = %v, want %v", got, tt.expected) + } + }) + } +} + +func TestMatchesModelRestrictionFilter(t *testing.T) { + accountLimited := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-5": "gpt-5", + }, + }, + } + accountUnlimited := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{}, + } + + tests := []struct { + name string + account *Account + filter string + expected bool + }{ + {name: "all_models_matches_everything", account: accountLimited, filter: "", expected: false}, + {name: "limited_matches_explicit_mapping", account: accountLimited, filter: AccountModelFilterLimited, expected: true}, + {name: "limited_rejects_missing_mapping", account: accountUnlimited, filter: AccountModelFilterLimited, expected: false}, + {name: "unlimited_matches_missing_mapping", account: accountUnlimited, filter: AccountModelFilterUnlimited, expected: true}, + {name: "unlimited_rejects_explicit_mapping", account: accountLimited, filter: AccountModelFilterUnlimited, expected: false}, + {name: "specific_model_still_uses_support_check", account: accountLimited, filter: "gpt-5", expected: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsAccountSupportedForModelFilter(tt.account, tt.filter); got != tt.expected { + t.Fatalf("IsAccountSupportedForModelFilter() = %v, want %v", got, tt.expected) + } + }) + } +} + +func TestMatchesAccountListStatusFilter(t *testing.T) { + now := time.Date(2026, 4, 26, 21, 0, 0, 0, time.UTC) + rateLimitedUntil := now.Add(5 * time.Minute) + tempUnschedUntil := now.Add(5 * time.Minute) + + activeQuotaOK := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Extra: map[string]any{ + "openai_quota_strategy": "prefer_5h", + "openai_quota_stop_threshold_percent": 10, + "codex_5h_used_percent": 80, + }, + } + activeQuotaStopped := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Extra: map[string]any{ + "openai_quota_strategy": "prefer_5h", + "openai_quota_stop_threshold_percent": 10, + "codex_5h_used_percent": 95, + }, + } + rateLimited := &Account{ + Status: StatusActive, + Schedulable: true, + RateLimitResetAt: &rateLimitedUntil, + } + tempUnsched := &Account{ + Status: StatusActive, + Schedulable: true, + TempUnschedulableUntil: &tempUnschedUntil, + } + unschedulable := &Account{ + Status: StatusActive, + Schedulable: false, + } + openAI5HZero := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "codex_5h_used_percent": 0.0, + }, + } + openAI7DZero := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "codex_7d_used_percent": 0.0, + }, + } + openAINonZero := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "codex_5h_used_percent": 12.0, + "codex_7d_used_percent": 8.0, + }, + } + expired7DWindow := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "codex_7d_used_percent": 91.0, + "codex_7d_reset_at": now.Add(-1 * time.Hour).Format(time.RFC3339), + }, + } + + require.True(t, MatchesAccountListStatusFilter(activeQuotaOK, AccountStatusFilterActiveExcludingQuotaStopped, now)) + require.False(t, MatchesAccountListStatusFilter(activeQuotaStopped, AccountStatusFilterActiveExcludingQuotaStopped, now)) + require.True(t, MatchesAccountListStatusFilter(rateLimited, "rate_limited", now)) + require.True(t, MatchesAccountListStatusFilter(tempUnsched, "temp_unschedulable", now)) + require.True(t, MatchesAccountListStatusFilter(unschedulable, "unschedulable", now)) + require.True(t, MatchesAccountListStatusFilter(openAI5HZero, AccountStatusFilterOpenAI5HUsedZero, now)) + require.True(t, MatchesAccountListStatusFilter(openAI7DZero, AccountStatusFilterOpenAI7DUsedZero, now)) + require.True(t, MatchesAccountListStatusFilter(expired7DWindow, AccountStatusFilterOpenAI7DUsedZero, now)) + require.False(t, MatchesAccountListStatusFilter(openAINonZero, AccountStatusFilterOpenAI5HUsedZero, now)) + require.False(t, MatchesAccountListStatusFilter(openAINonZero, AccountStatusFilterOpenAI7DUsedZero, now)) +} diff --git a/backend/internal/service/account_quota_schedulable_test.go b/backend/internal/service/account_quota_schedulable_test.go index 2895b34c889..ce616c6629a 100644 --- a/backend/internal/service/account_quota_schedulable_test.go +++ b/backend/internal/service/account_quota_schedulable_test.go @@ -121,3 +121,21 @@ func TestAccountIsSchedulable_QuotaExceeded(t *testing.T) { }) } } + +func TestOpenAIQuotaStrategySchedulable_ExpiredWindowTreatedAsReset(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_quota_strategy": "prefer_7d", + "openai_quota_stop_threshold_percent": 10, + "codex_7d_used_percent": 91.0, + "codex_7d_reset_at": time.Now().Add(-1 * time.Hour).Format(time.RFC3339), + }, + } + + require.True(t, account.IsOpenAIQuotaStrategySchedulable()) + remaining, ok := account.GetOpenAIQuotaRemainingPercentByStrategy() + require.True(t, ok) + require.Equal(t, 100.0, remaining) +} diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index 3189a7290fd..f35e847a8c1 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -37,7 +37,7 @@ type AccountRepository interface { Delete(ctx context.Context, id int64) error List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) - ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) + ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, model, quotaStrategy, proxyFilter, privacyMode string) ([]Account, *pagination.PaginationResult, error) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) ListActive(ctx context.Context) ([]Account, error) ListByPlatform(ctx context.Context, platform string) ([]Account, error) diff --git a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go index 81169a029b0..b48c5d860b5 100644 --- a/backend/internal/service/account_service_delete_test.go +++ b/backend/internal/service/account_service_delete_test.go @@ -79,7 +79,7 @@ func (s *accountRepoStub) List(ctx context.Context, params pagination.Pagination panic("unexpected List call") } -func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) { +func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, model, quotaStrategy, proxyFilter, privacyMode string) ([]Account, *pagination.PaginationResult, error) { panic("unexpected ListWithFilters call") } diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index 68ba8f8ce98..916fe7a6353 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -506,6 +506,16 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou if progress := buildCodexUsageProgressFromExtra(account.Extra, "7d", now); progress != nil { usage.SevenDay = progress } + if zeroUpdates := buildExpiredOpenAICodexZeroUpdates(account, now); len(zeroUpdates) > 0 { + mergeAccountExtra(account, zeroUpdates) + s.persistOpenAICodexProbeSnapshot(account.ID, zeroUpdates) + if progress := buildCodexUsageProgressFromExtra(account.Extra, "5h", now); progress != nil { + usage.FiveHour = progress + } + if progress := buildCodexUsageProgressFromExtra(account.Extra, "7d", now); progress != nil { + usage.SevenDay = progress + } + } if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) { if updates, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && len(updates) > 0 { @@ -672,6 +682,38 @@ func (s *AccountUsageService) persistOpenAICodexProbeSnapshot(accountID int64, u }() } +func buildExpiredOpenAICodexZeroUpdates(account *Account, now time.Time) map[string]any { + if account == nil || account.Extra == nil { + return nil + } + + type windowConfig struct { + window string + key string + } + configs := []windowConfig{ + {window: "5h", key: "codex_5h_used_percent"}, + {window: "7d", key: "codex_7d_used_percent"}, + } + + updates := make(map[string]any) + for _, config := range configs { + progress := buildCodexUsageProgressFromExtra(account.Extra, config.window, now) + if progress == nil || progress.ResetsAt == nil || now.Before(*progress.ResetsAt) { + continue + } + current := parseExtraFloat64(account.Extra[config.key]) + if current != 0 { + updates[config.key] = 0.0 + } + } + if len(updates) == 0 { + return nil + } + updates["codex_usage_updated_at"] = now.UTC().Truncate(time.Second).Format(time.RFC3339) + return updates +} + func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) { if resp == nil { return nil, nil diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index eb5994d5498..a2c7d49b3e1 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -67,7 +67,7 @@ type AdminService interface { ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) // Account management - ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]Account, int64, error) + ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, model, quotaStrategy, proxyFilter, privacyMode string, sortBy, sortOrder string) ([]Account, int64, error) GetAccount(ctx context.Context, id int64) (*Account, error) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) @@ -2323,9 +2323,9 @@ func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGrou } // Account management implementations -func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]Account, int64, error) { +func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, model, quotaStrategy, proxyFilter, privacyMode string, sortBy, sortOrder string) ([]Account, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder} - accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID, privacyMode) + accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID, model, quotaStrategy, proxyFilter, privacyMode) if err != nil { return nil, 0, err } @@ -2737,6 +2737,9 @@ func (s *adminServiceImpl) resolveBulkUpdateTargetIDs(ctx context.Context, filte filters.Status, filters.Search, groupID, + "", + "", + "", filters.PrivacyMode, "", "", diff --git a/backend/internal/service/admin_service_bulk_update_test.go b/backend/internal/service/admin_service_bulk_update_test.go index df415295b1b..455c8e79552 100644 --- a/backend/internal/service/admin_service_bulk_update_test.go +++ b/backend/internal/service/admin_service_bulk_update_test.go @@ -88,7 +88,7 @@ func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID in return nil, nil } -func (s *accountRepoStubForBulkUpdate) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) { +func (s *accountRepoStubForBulkUpdate) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, model, quotaStrategy, proxyFilter, privacyMode string) ([]Account, *pagination.PaginationResult, error) { s.listCalled = true s.lastListParams = params s.lastListFilters.platform = platform @@ -96,6 +96,7 @@ func (s *accountRepoStubForBulkUpdate) ListWithFilters(_ context.Context, params s.lastListFilters.status = status s.lastListFilters.search = search s.lastListFilters.groupID = groupID + s.lastListFilters.model = model s.lastListFilters.privacyMode = privacyMode if s.listErr != nil { return nil, nil, s.listErr diff --git a/backend/internal/service/admin_service_search_test.go b/backend/internal/service/admin_service_search_test.go index 595e99e344f..21474f5b6d6 100644 --- a/backend/internal/service/admin_service_search_test.go +++ b/backend/internal/service/admin_service_search_test.go @@ -19,19 +19,25 @@ type accountRepoStubForAdminList struct { listWithFiltersType string listWithFiltersStatus string listWithFiltersSearch string + listWithFiltersModel string + listWithFiltersQuota string + listWithFiltersProxy string listWithFiltersPrivacy string listWithFiltersAccounts []Account listWithFiltersResult *pagination.PaginationResult listWithFiltersErr error } -func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) { +func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, model, quotaStrategy, proxyFilter, privacyMode string) ([]Account, *pagination.PaginationResult, error) { s.listWithFiltersCalls++ s.listWithFiltersParams = params s.listWithFiltersPlatform = platform s.listWithFiltersType = accountType s.listWithFiltersStatus = status s.listWithFiltersSearch = search + s.listWithFiltersModel = model + s.listWithFiltersQuota = quotaStrategy + s.listWithFiltersProxy = proxyFilter s.listWithFiltersPrivacy = privacyMode if s.listWithFiltersErr != nil { @@ -170,7 +176,7 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) { } svc := &adminServiceImpl{accountRepo: repo} - accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0, "", "name", "ASC") + accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0, "", "", "", "", "name", "ASC") require.NoError(t, err) require.Equal(t, int64(10), total) require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts) @@ -192,7 +198,7 @@ func TestAdminService_ListAccounts_WithPrivacyMode(t *testing.T) { } svc := &adminServiceImpl{accountRepo: repo} - accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, StatusActive, "acc2", 0, PrivacyModeCFBlocked, "", "") + accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, StatusActive, "acc2", 0, "", "", "", PrivacyModeCFBlocked, "", "") require.NoError(t, err) require.Equal(t, int64(1), total) require.Equal(t, []Account{{ID: 2, Name: "acc2"}}, accounts) diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 728328373c6..193b760203d 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -92,7 +92,7 @@ func (m *mockAccountRepoForPlatform) Delete(ctx context.Context, id int64) error func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } -func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) { +func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, model, quotaStrategy, proxyFilter, privacyMode string) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) { diff --git a/backend/internal/service/gateway_request_test.go b/backend/internal/service/gateway_request_test.go index 40bd1186728..887ca4c49bf 100644 --- a/backend/internal/service/gateway_request_test.go +++ b/backend/internal/service/gateway_request_test.go @@ -5,6 +5,7 @@ package service import ( "encoding/json" "fmt" + "strconv" "strings" "testing" @@ -960,7 +961,7 @@ func TestParseGatewayRequest_MaxTokensBoundary(t *testing.T) { tests := []struct { name string body string - wantMaxTokens int + wantMaxTokens int64 wantErr bool }{ { @@ -979,9 +980,14 @@ func TestParseGatewayRequest_MaxTokensBoundary(t *testing.T) { wantMaxTokens: -1, }, { - name: "超大值不 panic", - body: `{"max_tokens":9999999999999999}`, - wantMaxTokens: 10000000000000000, // float64 精度导致 9999999999999999 → 1e16 + name: "超大值不 panic", + body: `{"max_tokens":9999999999999999}`, + wantMaxTokens: func() int64 { + if strconv.IntSize == 32 { + return 0 + } + return 10000000000000000 + }(), }, { name: "null 值被忽略", @@ -998,7 +1004,7 @@ func TestParseGatewayRequest_MaxTokensBoundary(t *testing.T) { return } require.NoError(t, err) - require.Equal(t, tt.wantMaxTokens, parsed.MaxTokens) + require.Equal(t, tt.wantMaxTokens, int64(parsed.MaxTokens)) }) } } diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 5e09b95af29..ac96ca0e38e 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -79,7 +79,7 @@ func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } -func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) { +func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, model, quotaStrategy, proxyFilter, privacyMode string) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) { diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index 08a74a37245..5231a99a58a 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -339,19 +339,19 @@ func inferGoogleOneTier(storageBytes int64) string { } if storageBytes > StorageTierUnlimited { - logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - > %d bytes (100TB), returning UNLIMITED", StorageTierUnlimited) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - > %d bytes (100TB), returning UNLIMITED", int64(StorageTierUnlimited)) return GeminiTierGoogleAIUltra } if storageBytes >= StorageTierAIPremium { - logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - >= %d bytes (2TB), returning google_ai_pro", StorageTierAIPremium) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - >= %d bytes (2TB), returning google_ai_pro", int64(StorageTierAIPremium)) return GeminiTierGoogleAIPro } if storageBytes >= StorageTierFree { - logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - >= %d bytes (15GB), returning FREE", StorageTierFree) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - >= %d bytes (15GB), returning FREE", int64(StorageTierFree)) return GeminiTierGoogleOneFree } - logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - < %d bytes (15GB), returning UNKNOWN", StorageTierFree) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - < %d bytes (15GB), returning UNKNOWN", int64(StorageTierFree)) return GeminiTierGoogleOneUnknown } diff --git a/backend/internal/service/openai_ws_ratelimit_signal_test.go b/backend/internal/service/openai_ws_ratelimit_signal_test.go index 4ee85a3a09d..26e3fd061e2 100644 --- a/backend/internal/service/openai_ws_ratelimit_signal_test.go +++ b/backend/internal/service/openai_ws_ratelimit_signal_test.go @@ -73,12 +73,15 @@ func (r *openAICodexExtraListRepo) SetRateLimited(_ context.Context, _ int64, re return nil } -func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) { +func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, model, quotaStrategy, proxyFilter, privacyMode string) ([]Account, *pagination.PaginationResult, error) { _ = platform _ = accountType _ = status _ = search _ = groupID + _ = model + _ = quotaStrategy + _ = proxyFilter _ = privacyMode return r.accounts, &pagination.PaginationResult{Total: int64(len(r.accounts)), Page: params.Page, PageSize: params.PageSize}, nil } @@ -487,7 +490,7 @@ func TestAdminService_ListAccounts_ExhaustedCodexExtraDoesNotSetRateLimit(t *tes } svc := &adminServiceImpl{accountRepo: repo} - accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0, "", "", "") + accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0, "", "", "", "", "", "") require.NoError(t, err) require.Equal(t, int64(1), total) require.Len(t, accounts, 1) diff --git a/backend/internal/service/ops_concurrency.go b/backend/internal/service/ops_concurrency.go index 69b513af830..3c00bf0de88 100644 --- a/backend/internal/service/ops_concurrency.go +++ b/backend/internal/service/ops_concurrency.go @@ -24,7 +24,7 @@ func (s *OpsService) listAllAccountsForOps(ctx context.Context, platformFilter s accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{ Page: page, PageSize: opsAccountsPageSize, - }, platformFilter, "", "", "", 0, "") + }, platformFilter, "", "", "", 0, "", "", "", "") if err != nil { return nil, err } diff --git a/backend/internal/service/ratelimit_session_window_test.go b/backend/internal/service/ratelimit_session_window_test.go index 7796a85e765..12d3ad40c68 100644 --- a/backend/internal/service/ratelimit_session_window_test.go +++ b/backend/internal/service/ratelimit_session_window_test.go @@ -81,7 +81,7 @@ func (m *sessionWindowMockRepo) Delete(context.Context, int64) error { panic( func (m *sessionWindowMockRepo) List(context.Context, pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { panic("unexpected") } -func (m *sessionWindowMockRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string) ([]Account, *pagination.PaginationResult, error) { +func (m *sessionWindowMockRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string, string, string, string) ([]Account, *pagination.PaginationResult, error) { panic("unexpected") } func (m *sessionWindowMockRepo) ListByGroup(context.Context, int64) ([]Account, error) { diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts index 00ed40878c3..1b6624497f1 100644 --- a/frontend/src/api/admin/accounts.ts +++ b/frontend/src/api/admin/accounts.ts @@ -21,6 +21,7 @@ import type { CheckMixedChannelRequest, CheckMixedChannelResponse } from '@/types' +import type { AccountModelFilterGroup } from '@/components/admin/account/accountModelFilter' /** * List all accounts with pagination @@ -37,6 +38,9 @@ export async function list( type?: string status?: string group?: string + model?: string + quota_strategy?: string + proxy_filter?: string search?: string privacy_mode?: string lite?: string @@ -72,6 +76,9 @@ export async function listWithEtag( type?: string status?: string group?: string + model?: string + quota_strategy?: string + proxy_filter?: string search?: string privacy_mode?: string lite?: string @@ -446,6 +453,13 @@ export async function getAvailableModels(id: number): Promise { return data } +export async function getFilterModels(platform?: string): Promise { + const { data } = await apiClient.get('/admin/accounts/filter-models', { + params: platform ? { platform } : undefined + }) + return data +} + export interface CRSPreviewAccount { crs_account_id: string kind: string @@ -510,6 +524,9 @@ export async function exportData(options?: { type?: string status?: string group?: string + model?: string + quota_strategy?: string + proxy_filter?: string privacy_mode?: string search?: string sort_by?: string @@ -521,11 +538,14 @@ export async function exportData(options?: { if (options?.ids && options.ids.length > 0) { params.ids = options.ids.join(',') } else if (options?.filters) { - const { platform, type, status, group, privacy_mode, search, sort_by, sort_order } = options.filters + const { platform, type, status, group, model, quota_strategy, proxy_filter, privacy_mode, search, sort_by, sort_order } = options.filters if (platform) params.platform = platform if (type) params.type = type if (status) params.status = status if (group) params.group = group + if (model) params.model = model + if (quota_strategy) params.quota_strategy = quota_strategy + if (proxy_filter) params.proxy_filter = proxy_filter if (privacy_mode) params.privacy_mode = privacy_mode if (search) params.search = search if (sort_by) params.sort_by = sort_by @@ -627,6 +647,15 @@ export async function batchRefresh(accountIds: number[]): Promise { + const { data } = await apiClient.post('/admin/accounts/batch-test', { + account_ids: accountIds, + }, { + timeout: 600000 + }) + return data +} + /** * Set privacy for an Antigravity OAuth account * @param id - Account ID @@ -660,6 +689,7 @@ export const accountsAPI = { resetTempUnschedulable, setSchedulable, getAvailableModels, + getFilterModels, generateAuthUrl, exchangeCode, refreshOpenAIToken, @@ -674,6 +704,7 @@ export const accountsAPI = { getAntigravityDefaultModelMapping, batchClearError, batchRefresh, + batchTest, setPrivacy } diff --git a/frontend/src/components/admin/account/AccountBulkActionsBar.vue b/frontend/src/components/admin/account/AccountBulkActionsBar.vue index a632bdd4213..3275d304b8f 100644 --- a/frontend/src/components/admin/account/AccountBulkActionsBar.vue +++ b/frontend/src/components/admin/account/AccountBulkActionsBar.vue @@ -26,6 +26,7 @@