Merge pull request #4431 from yyhhyyyyyy/fix/tiered-billing-model-list

fix: include tiered billing models in model listing
This commit is contained in:
yyhhyyyyyy
2026-04-24 16:24:36 +08:00
committed by GitHub
parent 2e610e5fb3
commit e3d64cb76d
5 changed files with 272 additions and 19 deletions
+2 -5
View File
@@ -17,7 +17,6 @@ import (
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
@@ -134,8 +133,7 @@ func ListModels(c *gin.Context, modelType int) {
}
for allowModel, _ := range tokenModelLimit {
if !acceptUnsetRatioModel {
_, _, exist := ratio_setting.GetModelRatioOrPrice(allowModel)
if !exist {
if !model.HasModelBillingConfig(allowModel) {
continue
}
}
@@ -182,8 +180,7 @@ func ListModels(c *gin.Context, modelType int) {
}
for _, modelName := range models {
if !acceptUnsetRatioModel {
_, _, exist := ratio_setting.GetModelRatioOrPrice(modelName)
if !exist {
if !model.HasModelBillingConfig(modelName) {
continue
}
}
+242
View File
@@ -0,0 +1,242 @@
package controller
import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/config"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/gin-gonic/gin"
"github.com/glebarez/sqlite"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
)
type listModelsResponse struct {
Success bool `json:"success"`
Data []dto.OpenAIModels `json:"data"`
Object string `json:"object"`
}
func setupModelListControllerTestDB(t *testing.T) *gorm.DB {
t.Helper()
initModelListColumnNames(t)
gin.SetMode(gin.TestMode)
common.UsingSQLite = true
common.UsingMySQL = false
common.UsingPostgreSQL = false
common.RedisEnabled = false
dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_"))
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
require.NoError(t, err)
model.DB = db
model.LOG_DB = db
require.NoError(t, db.AutoMigrate(&model.User{}, &model.Channel{}, &model.Ability{}, &model.Model{}, &model.Vendor{}))
t.Cleanup(func() {
sqlDB, err := db.DB()
if err == nil {
_ = sqlDB.Close()
}
})
return db
}
func initModelListColumnNames(t *testing.T) {
t.Helper()
originalIsMasterNode := common.IsMasterNode
originalSQLitePath := common.SQLitePath
originalUsingSQLite := common.UsingSQLite
originalUsingMySQL := common.UsingMySQL
originalUsingPostgreSQL := common.UsingPostgreSQL
originalSQLDSN, hadSQLDSN := os.LookupEnv("SQL_DSN")
defer func() {
common.IsMasterNode = originalIsMasterNode
common.SQLitePath = originalSQLitePath
common.UsingSQLite = originalUsingSQLite
common.UsingMySQL = originalUsingMySQL
common.UsingPostgreSQL = originalUsingPostgreSQL
if hadSQLDSN {
require.NoError(t, os.Setenv("SQL_DSN", originalSQLDSN))
} else {
require.NoError(t, os.Unsetenv("SQL_DSN"))
}
}()
common.IsMasterNode = false
common.SQLitePath = fmt.Sprintf("file:%s_init?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_"))
common.UsingSQLite = false
common.UsingMySQL = false
common.UsingPostgreSQL = false
require.NoError(t, os.Setenv("SQL_DSN", "local"))
require.NoError(t, model.InitDB())
if model.DB != nil {
sqlDB, err := model.DB.DB()
if err == nil {
_ = sqlDB.Close()
}
}
}
func withTieredBillingConfig(t *testing.T, modes map[string]string, exprs map[string]string) {
t.Helper()
saved := map[string]string{}
require.NoError(t, config.GlobalConfig.SaveToDB(func(key, value string) error {
if strings.HasPrefix(key, "billing_setting.") {
saved[key] = value
}
return nil
}))
t.Cleanup(func() {
require.NoError(t, config.GlobalConfig.LoadFromDB(saved))
model.InvalidatePricingCache()
})
modeBytes, err := common.Marshal(modes)
require.NoError(t, err)
exprBytes, err := common.Marshal(exprs)
require.NoError(t, err)
require.NoError(t, config.GlobalConfig.LoadFromDB(map[string]string{
"billing_setting.billing_mode": string(modeBytes),
"billing_setting.billing_expr": string(exprBytes),
}))
model.InvalidatePricingCache()
}
func withSelfUseModeDisabled(t *testing.T) {
t.Helper()
original := operation_setting.SelfUseModeEnabled
operation_setting.SelfUseModeEnabled = false
t.Cleanup(func() {
operation_setting.SelfUseModeEnabled = original
})
}
func decodeListModelsResponse(t *testing.T, recorder *httptest.ResponseRecorder) map[string]struct{} {
t.Helper()
require.Equal(t, http.StatusOK, recorder.Code)
var payload listModelsResponse
require.NoError(t, common.Unmarshal(recorder.Body.Bytes(), &payload))
require.True(t, payload.Success)
require.Equal(t, "list", payload.Object)
ids := make(map[string]struct{}, len(payload.Data))
for _, item := range payload.Data {
ids[item.Id] = struct{}{}
}
return ids
}
func pricingByModelName(pricings []model.Pricing) map[string]model.Pricing {
byName := make(map[string]model.Pricing, len(pricings))
for _, pricing := range pricings {
byName[pricing.ModelName] = pricing
}
return byName
}
func TestListModelsIncludesTieredBillingModel(t *testing.T) {
withSelfUseModeDisabled(t)
withTieredBillingConfig(t, map[string]string{
"zz-tiered-visible-model": "tiered_expr",
"zz-tiered-empty-expr-model": "tiered_expr",
"zz-tiered-missing-expr-model": "tiered_expr",
}, map[string]string{
"zz-tiered-visible-model": `tier("base", p * 1 + c * 2)`,
"zz-tiered-empty-expr-model": " ",
})
db := setupModelListControllerTestDB(t)
require.NoError(t, db.Create(&model.User{
Id: 1001,
Username: "model-list-user",
Password: "password",
Group: "default",
Status: common.UserStatusEnabled,
}).Error)
require.NoError(t, db.Create(&[]model.Ability{
{Group: "default", Model: "zz-tiered-visible-model", ChannelId: 1, Enabled: true},
{Group: "default", Model: "zz-tiered-empty-expr-model", ChannelId: 1, Enabled: true},
{Group: "default", Model: "zz-tiered-missing-expr-model", ChannelId: 1, Enabled: true},
{Group: "default", Model: "zz-unpriced-model", ChannelId: 1, Enabled: true},
}).Error)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
ctx.Set("id", 1001)
ListModels(ctx, constant.ChannelTypeOpenAI)
ids := decodeListModelsResponse(t, recorder)
require.Contains(t, ids, "zz-tiered-visible-model")
require.NotContains(t, ids, "zz-tiered-empty-expr-model")
require.NotContains(t, ids, "zz-tiered-missing-expr-model")
require.NotContains(t, ids, "zz-unpriced-model")
pricingByName := pricingByModelName(model.GetPricing())
visiblePricing, ok := pricingByName["zz-tiered-visible-model"]
require.True(t, ok)
require.Equal(t, "tiered_expr", visiblePricing.BillingMode)
require.NotEmpty(t, visiblePricing.BillingExpr)
emptyExprPricing, ok := pricingByName["zz-tiered-empty-expr-model"]
require.True(t, ok)
require.Empty(t, emptyExprPricing.BillingMode)
require.Empty(t, emptyExprPricing.BillingExpr)
missingExprPricing, ok := pricingByName["zz-tiered-missing-expr-model"]
require.True(t, ok)
require.Empty(t, missingExprPricing.BillingMode)
require.Empty(t, missingExprPricing.BillingExpr)
}
func TestListModelsTokenLimitIncludesTieredBillingModel(t *testing.T) {
withSelfUseModeDisabled(t)
withTieredBillingConfig(t, map[string]string{
"zz-token-tiered-visible-model": "tiered_expr",
"zz-token-tiered-empty-expr-model": "tiered_expr",
"zz-token-tiered-missing-expr-model": "tiered_expr",
}, map[string]string{
"zz-token-tiered-visible-model": `tier("base", p * 1 + c * 2)`,
"zz-token-tiered-empty-expr-model": "",
})
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
common.SetContextKey(ctx, constant.ContextKeyTokenModelLimitEnabled, true)
common.SetContextKey(ctx, constant.ContextKeyTokenModelLimit, map[string]bool{
"zz-token-tiered-visible-model": true,
"zz-token-tiered-empty-expr-model": true,
"zz-token-tiered-missing-expr-model": true,
"zz-token-unpriced-model": true,
})
ListModels(ctx, constant.ChannelTypeOpenAI)
ids := decodeListModelsResponse(t, recorder)
require.Contains(t, ids, "zz-token-tiered-visible-model")
require.NotContains(t, ids, "zz-token-tiered-empty-expr-model")
require.NotContains(t, ids, "zz-token-tiered-missing-expr-model")
require.NotContains(t, ids, "zz-token-unpriced-model")
}
+3
View File
@@ -578,6 +578,9 @@ func handleConfigUpdate(key, value string) bool {
performance_setting.UpdateAndSync()
} else if configName == "tool_price_setting" {
operation_setting.RebuildToolPriceIndex()
} else if configName == "billing_setting" {
InvalidatePricingCache()
ratio_setting.InvalidateExposedDataCache()
}
return true // 已处理
+24 -1
View File
@@ -77,6 +77,29 @@ func GetPricing() []Pricing {
return pricingMap
}
func InvalidatePricingCache() {
updatePricingLock.Lock()
defer updatePricingLock.Unlock()
pricingMap = nil
vendorsList = nil
lastGetPricingTime = time.Time{}
}
func HasModelBillingConfig(modelName string) bool {
if _, ok := ratio_setting.GetModelPrice(modelName, false); ok {
return true
}
if _, ok, _ := ratio_setting.GetModelRatio(modelName); ok {
return true
}
if billing_setting.GetBillingMode(modelName) != billing_setting.BillingModeTieredExpr {
return false
}
expr, ok := billing_setting.GetBillingExpr(modelName)
return ok && strings.TrimSpace(expr) != ""
}
// GetVendors 返回当前定价接口使用到的供应商信息
func GetVendors() []PricingVendor {
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
@@ -323,7 +346,7 @@ func updatePricing() {
pricing.AudioCompletionRatio = &audioCompletionRatio
}
if billingMode := billing_setting.GetBillingMode(model); billingMode == "tiered_expr" {
if expr, ok := billing_setting.GetBillingExpr(model); ok && expr != "" {
if expr, ok := billing_setting.GetBillingExpr(model); ok && strings.TrimSpace(expr) != "" {
pricing.BillingMode = billingMode
pricing.BillingExpr = expr
}
+1 -13
View File
@@ -224,19 +224,7 @@ func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) (types
}
func ContainPriceOrRatio(modelName string) bool {
_, ok := ratio_setting.GetModelPrice(modelName, false)
if ok {
return true
}
_, ok, _ = ratio_setting.GetModelRatio(modelName)
if ok {
return true
}
if billing_setting.GetBillingMode(modelName) == billing_setting.BillingModeTieredExpr {
_, ok = billing_setting.GetBillingExpr(modelName)
return ok
}
return false
return model.HasModelBillingConfig(modelName)
}
func modelPriceHelperTiered(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, meta *types.TokenCountMeta, groupRatioInfo types.GroupRatioInfo) (types.PriceData, error) {