diff --git a/model/ability.go b/model/ability.go index 5cfb994973..66efd0c110 100644 --- a/model/ability.go +++ b/model/ability.go @@ -67,11 +67,19 @@ func (channel *Channel) AddAbilities() error { abilities = append(abilities, ability) } } - return DB.Create(&abilities).Error + err := DB.Create(&abilities).Error + if err == nil { + ClearGroupModelsCacheByGroups(groups_) + } + return err } func (channel *Channel) DeleteAbilities() error { - return DB.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error + err := DB.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error + if err == nil { + ClearGroupModelsCacheByGroups(strings.Split(channel.Group, ",")) + } + return err } // UpdateAbilities updates abilities of this channel. diff --git a/model/cache.go b/model/cache.go index cfb0f8a483..9e1cc99e31 100644 --- a/model/cache.go +++ b/model/cache.go @@ -167,6 +167,23 @@ func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) { return models, nil } +// ClearGroupModelsCacheByGroups deletes the cached group model lists for the +// given group names so the next query fetches fresh data from the database. +func ClearGroupModelsCacheByGroups(groups []string) { + if !common.RedisEnabled { + return + } + for _, group := range groups { + if group == "" { + continue + } + err := common.RedisDel(fmt.Sprintf("group_models:%s", group)) + if err != nil { + logger.SysError(fmt.Sprintf("Redis delete group_models cache error for group %s: %s", group, err.Error())) + } + } +} + var group2model2channels map[string]map[string][]*Channel var channelSyncLock sync.RWMutex diff --git a/model/channel.go b/model/channel.go index 4b0f4b01aa..3258215b6e 100644 --- a/model/channel.go +++ b/model/channel.go @@ -3,7 +3,9 @@ package model import ( "encoding/json" "fmt" + "strings" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" @@ -196,6 +198,15 @@ func UpdateChannelStatusById(id int, status int) { if err != nil { logger.SysError("failed to update channel status: " + err.Error()) } + // Invalidate the group models cache so users immediately see the updated model list. + groupCol := "`group`" + if common.UsingPostgreSQL { + groupCol = `"group"` + } + var groupStr string + if dbErr := DB.Model(&Channel{}).Where("id = ?", id).Select(groupCol).Pluck(groupCol, &groupStr).Error; dbErr == nil && groupStr != "" { + ClearGroupModelsCacheByGroups(strings.Split(groupStr, ",")) + } } func UpdateChannelUsedQuota(id int, quota int64) {