Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions .github/workflows/docker-push.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
name: Deploy Sub2api to GHCR

on:
push:
branches: [ "main" ] # 只有推送到 main 分支时触发,如果是 master 请修改

env:
REGISTRY: ghcr.io
# 镜像名会自动设为:用户名/仓库名
IMAGE_NAME: ${{ github.repository }}

jobs:
build-and-push:
runs-on: ubuntu-latest # 使用 GitHub 提供的 amd64 环境构建
permissions:
contents: read
packages: write

steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Log in to GHCR
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}

- name: Build and push
uses: docker/build-push-action@v5
with:
context: .
push: true
# 明确指定 amd64,确保服务器能跑
platforms: linux/amd64
tags: |
ghcr.io/${{ env.IMAGE_NAME }}:latest
ghcr.io/${{ env.IMAGE_NAME }}:${{ github.sha }}
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ WORKDIR /app/frontend
RUN corepack enable && corepack prepare pnpm@latest --activate

# Install dependencies first (better caching)
COPY frontend/package.json frontend/pnpm-lock.yaml ./
COPY frontend/package.json frontend/pnpm-lock.yaml frontend/pnpm-workspace.yaml ./
RUN pnpm install --frozen-lockfile

# Copy frontend source and build
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ English | [中文](README_CN.md) | [日本語](README_JA.md)

---

## Demo
## DemoA

Try Sub2API online: **[https://demo.sub2api.org/](https://demo.sub2api.org/)**

Expand Down
6 changes: 4 additions & 2 deletions backend/cmd/server/wire_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion backend/internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ const (

// DefaultCSPPolicy is the default Content-Security-Policy with nonce support
// __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com https://*.stripe.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com https://*.stripe.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com https://*.stripe.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https: blob:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com https://*.stripe.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"

// UMQ(用户消息队列)模式常量
const (
Expand Down
16 changes: 16 additions & 0 deletions backend/internal/handler/admin/setting_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {

AvailableChannelsEnabled: settings.AvailableChannelsEnabled,

ImageGenerationEnabled: settings.ImageGenerationEnabled,

AffiliateEnabled: settings.AffiliateEnabled,
}

Expand Down Expand Up @@ -494,6 +496,9 @@ type UpdateSettingsRequest struct {
// Available Channels feature switch (user-facing)
AvailableChannelsEnabled *bool `json:"available_channels_enabled"`

// Image Generation feature switch (user-facing)
ImageGenerationEnabled *bool `json:"image_generation_enabled"`

// Affiliate (邀请返利) feature switch
AffiliateEnabled *bool `json:"affiliate_enabled"`

Expand Down Expand Up @@ -1359,6 +1364,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
return previousSettings.AvailableChannelsEnabled
}(),
ImageGenerationEnabled: func() bool {
if req.ImageGenerationEnabled != nil {
return *req.ImageGenerationEnabled
}
return previousSettings.ImageGenerationEnabled
}(),
AffiliateEnabled: func() bool {
if req.AffiliateEnabled != nil {
return *req.AffiliateEnabled
Expand Down Expand Up @@ -1615,6 +1626,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {

AvailableChannelsEnabled: updatedSettings.AvailableChannelsEnabled,

ImageGenerationEnabled: updatedSettings.ImageGenerationEnabled,

AffiliateEnabled: updatedSettings.AffiliateEnabled,
}
if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil {
Expand Down Expand Up @@ -2001,6 +2014,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.AvailableChannelsEnabled != after.AvailableChannelsEnabled {
changed = append(changed, "available_channels_enabled")
}
if before.ImageGenerationEnabled != after.ImageGenerationEnabled {
changed = append(changed, "image_generation_enabled")
}
if before.AffiliateEnabled != after.AffiliateEnabled {
changed = append(changed, "affiliate_enabled")
}
Expand Down
106 changes: 97 additions & 9 deletions backend/internal/handler/available_channel_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package handler

import (
"sort"
"strings"

"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
Expand All @@ -24,18 +25,21 @@ type AvailableChannelHandler struct {
channelService *service.ChannelService
apiKeyService *service.APIKeyService
settingService *service.SettingService
billingService *service.BillingService
}

// NewAvailableChannelHandler 创建用户侧可用渠道 handler。
func NewAvailableChannelHandler(
channelService *service.ChannelService,
apiKeyService *service.APIKeyService,
settingService *service.SettingService,
billingService *service.BillingService,
) *AvailableChannelHandler {
return &AvailableChannelHandler{
channelService: channelService,
apiKeyService: apiKeyService,
settingService: settingService,
billingService: billingService,
}
}

Expand Down Expand Up @@ -111,6 +115,31 @@ type userAvailableChannel struct {
Platforms []userChannelPlatformSection `json:"platforms"`
}

type userModelPricingBatchRequest struct {
Models []string `json:"models"`
}

type userDefaultModelPricing struct {
Found bool `json:"found"`
BillingMode string `json:"billing_mode,omitempty"`
InputPrice *float64 `json:"input_price,omitempty"`
OutputPrice *float64 `json:"output_price,omitempty"`
CacheWritePrice *float64 `json:"cache_write_price,omitempty"`
CacheReadPrice *float64 `json:"cache_read_price,omitempty"`
ImageOutputPrice *float64 `json:"image_output_price,omitempty"`
PerRequestPrice *float64 `json:"per_request_price,omitempty"`
}

type userModelPricingBatchResponse struct {
Prices map[string]userDefaultModelPricing `json:"prices"`
}

// ListPublic 列出公开模型广场可见的「可用渠道」。
// GET /api/v1/public/channels/available
func (h *AvailableChannelHandler) ListPublic(c *gin.Context) {
h.listVisible(c, nil)
}

// List 列出当前用户可见的「可用渠道」。
// GET /api/v1/channels/available
func (h *AvailableChannelHandler) List(c *gin.Context) {
Expand All @@ -120,13 +149,6 @@ func (h *AvailableChannelHandler) List(c *gin.Context) {
return
}

// Feature 未启用时返回空数组(不暴露渠道信息)。检查放在认证之后,
// 保持与未开关前的 401 行为一致:未登录先 401,登录后再按开关决定。
if !h.featureEnabled(c) {
response.Success(c, []userAvailableChannel{})
return
}

userGroups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
Expand All @@ -136,6 +158,14 @@ func (h *AvailableChannelHandler) List(c *gin.Context) {
for i := range userGroups {
allowedGroupIDs[userGroups[i].ID] = struct{}{}
}
h.listVisible(c, allowedGroupIDs)
}

func (h *AvailableChannelHandler) listVisible(c *gin.Context, allowedGroupIDs map[int64]struct{}) {
if !h.featureEnabled(c) {
response.Success(c, []userAvailableChannel{})
return
}

channels, err := h.channelService.ListAvailable(c.Request.Context())
if err != nil {
Expand All @@ -148,7 +178,7 @@ func (h *AvailableChannelHandler) List(c *gin.Context) {
if ch.Status != service.StatusActive {
continue
}
visibleGroups := filterUserVisibleGroups(ch.Groups, allowedGroupIDs)
visibleGroups := filterVisibleGroups(ch.Groups, allowedGroupIDs)
if len(visibleGroups) == 0 {
continue
}
Expand All @@ -166,6 +196,52 @@ func (h *AvailableChannelHandler) List(c *gin.Context) {
response.Success(c, out)
}

// GetModelPricingBatch 批量查询模型默认定价。
// POST /api/v1/channels/model-pricing/batch
func (h *AvailableChannelHandler) GetModelPricingBatch(c *gin.Context) {
if h.billingService == nil {
response.InternalError(c, "Billing service not available")
return
}

var req userModelPricingBatchRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "invalid request body")
return
}

prices := make(map[string]userDefaultModelPricing, len(req.Models))
seen := make(map[string]struct{}, len(req.Models))
for _, raw := range req.Models {
model := strings.TrimSpace(raw)
if model == "" {
continue
}
key := strings.ToLower(model)
if _, exists := seen[key]; exists {
continue
}
seen[key] = struct{}{}

pricing, err := h.billingService.GetModelPricing(model)
if err != nil {
prices[model] = userDefaultModelPricing{Found: false}
continue
}
prices[model] = userDefaultModelPricing{
Found: true,
BillingMode: string(service.BillingModeToken),
InputPrice: &pricing.InputPricePerToken,
OutputPrice: &pricing.OutputPricePerToken,
CacheWritePrice: &pricing.CacheCreationPricePerToken,
CacheReadPrice: &pricing.CacheReadPricePerToken,
ImageOutputPrice: &pricing.ImageOutputPricePerToken,
}
}

response.Success(c, userModelPricingBatchResponse{Prices: prices})
}

// buildPlatformSections 把一个渠道按 visibleGroups 的平台集合拆成有序的 section 列表:
// 每个 section 对应一个平台,只包含该平台的 groups 和 supported_models。
// 输出按 platform 字母序稳定排序,便于前端等效比较与回归测试。
Expand Down Expand Up @@ -206,10 +282,22 @@ func buildPlatformSections(
func filterUserVisibleGroups(
groups []service.AvailableGroupRef,
allowed map[int64]struct{},
) []userAvailableGroup {
return filterVisibleGroups(groups, allowed)
}

// filterVisibleGroups 过滤可见分组。allowed 为 nil 时表示匿名模型广场,只展示公开分组。
func filterVisibleGroups(
groups []service.AvailableGroupRef,
allowed map[int64]struct{},
) []userAvailableGroup {
visible := make([]userAvailableGroup, 0, len(groups))
for _, g := range groups {
if _, ok := allowed[g.ID]; !ok {
if allowed == nil {
if g.IsExclusive {
continue
}
} else if _, ok := allowed[g.ID]; !ok {
continue
}
visible = append(visible, userAvailableGroup{
Expand Down
51 changes: 51 additions & 0 deletions backend/internal/handler/available_channel_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
package handler

import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"

"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"

"github.com/gin-gonic/gin"
Expand All @@ -27,6 +30,41 @@ func TestUserAvailableChannel_Unauthenticated401(t *testing.T) {
require.Equal(t, http.StatusUnauthorized, w.Code)
}

func TestUserModelPricingBatch_ReturnsPricingForKnownAndUnknownModels(t *testing.T) {
gin.SetMode(gin.TestMode)
h := &AvailableChannelHandler{
billingService: service.NewBillingService(&config.Config{}, nil),
}
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(
http.MethodPost,
"/api/v1/channels/model-pricing/batch",
bytes.NewBufferString(`{"models":["gpt-5.4","totally-unknown-model"]}`),
)
c.Request.Header.Set("Content-Type", "application/json")
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 1})

h.GetModelPricingBatch(c)

require.Equal(t, http.StatusOK, w.Code)
var resp struct {
Code int `json:"code"`
Data struct {
Prices map[string]struct {
Found bool `json:"found"`
InputPrice *float64 `json:"input_price"`
} `json:"prices"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.True(t, resp.Data.Prices["gpt-5.4"].Found)
require.NotNil(t, resp.Data.Prices["gpt-5.4"].InputPrice)
require.False(t, resp.Data.Prices["totally-unknown-model"].Found)
require.Nil(t, resp.Data.Prices["totally-unknown-model"].InputPrice)
}

func TestFilterUserVisibleGroups_IntersectionOnly(t *testing.T) {
// 渠道挂在 {g1, g2, g3},用户只允许 {g1, g3} —— 响应必须仅含 g1/g3。
groups := []service.AvailableGroupRef{
Expand All @@ -42,6 +80,19 @@ func TestFilterUserVisibleGroups_IntersectionOnly(t *testing.T) {
require.ElementsMatch(t, []int64{1, 3}, ids)
}

func TestFilterVisibleGroups_PublicOnlyForAnonymous(t *testing.T) {
groups := []service.AvailableGroupRef{
{ID: 1, Name: "public", Platform: "openai", IsExclusive: false},
{ID: 2, Name: "exclusive", Platform: "openai", IsExclusive: true},
}

visible := filterVisibleGroups(groups, nil)

require.Len(t, visible, 1)
require.Equal(t, int64(1), visible[0].ID)
require.False(t, visible[0].IsExclusive)
}

func TestToUserSupportedModels_FiltersByAllowedPlatforms(t *testing.T) {
// 用户可访问分组只覆盖 anthropic;anthropic 平台的模型保留,openai 模型被剔除。
src := []service.SupportedModel{
Expand Down
Loading
Loading