diff --git a/controller/model.go b/controller/model.go index aa6c6e2b..f2237955 100644 --- a/controller/model.go +++ b/controller/model.go @@ -17,7 +17,6 @@ import ( relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" "github.com/QuantumNous/new-api/setting/operation_setting" - "github.com/QuantumNous/new-api/setting/ratio_setting" "github.com/QuantumNous/new-api/types" "github.com/gin-gonic/gin" "github.com/samber/lo" @@ -134,8 +133,7 @@ func ListModels(c *gin.Context, modelType int) { } for allowModel, _ := range tokenModelLimit { if !acceptUnsetRatioModel { - _, _, exist := ratio_setting.GetModelRatioOrPrice(allowModel) - if !exist { + if !model.HasModelBillingConfig(allowModel) { continue } } @@ -182,8 +180,7 @@ func ListModels(c *gin.Context, modelType int) { } for _, modelName := range models { if !acceptUnsetRatioModel { - _, _, exist := ratio_setting.GetModelRatioOrPrice(modelName) - if !exist { + if !model.HasModelBillingConfig(modelName) { continue } } diff --git a/controller/model_list_test.go b/controller/model_list_test.go new file mode 100644 index 00000000..97d27cae --- /dev/null +++ b/controller/model_list_test.go @@ -0,0 +1,242 @@ +package controller + +import ( + "fmt" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/constant" + "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/model" + "github.com/QuantumNous/new-api/setting/config" + "github.com/QuantumNous/new-api/setting/operation_setting" + "github.com/gin-gonic/gin" + "github.com/glebarez/sqlite" + "github.com/stretchr/testify/require" + "gorm.io/gorm" +) + +type listModelsResponse struct { + Success bool `json:"success"` + Data []dto.OpenAIModels `json:"data"` + Object string `json:"object"` +} + +func setupModelListControllerTestDB(t *testing.T) *gorm.DB { + t.Helper() + + initModelListColumnNames(t) + + gin.SetMode(gin.TestMode) + common.UsingSQLite = true + common.UsingMySQL = false + common.UsingPostgreSQL = false + common.RedisEnabled = false + + dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_")) + db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{}) + require.NoError(t, err) + model.DB = db + model.LOG_DB = db + + require.NoError(t, db.AutoMigrate(&model.User{}, &model.Channel{}, &model.Ability{}, &model.Model{}, &model.Vendor{})) + + t.Cleanup(func() { + sqlDB, err := db.DB() + if err == nil { + _ = sqlDB.Close() + } + }) + + return db +} + +func initModelListColumnNames(t *testing.T) { + t.Helper() + + originalIsMasterNode := common.IsMasterNode + originalSQLitePath := common.SQLitePath + originalUsingSQLite := common.UsingSQLite + originalUsingMySQL := common.UsingMySQL + originalUsingPostgreSQL := common.UsingPostgreSQL + originalSQLDSN, hadSQLDSN := os.LookupEnv("SQL_DSN") + defer func() { + common.IsMasterNode = originalIsMasterNode + common.SQLitePath = originalSQLitePath + common.UsingSQLite = originalUsingSQLite + common.UsingMySQL = originalUsingMySQL + common.UsingPostgreSQL = originalUsingPostgreSQL + if hadSQLDSN { + require.NoError(t, os.Setenv("SQL_DSN", originalSQLDSN)) + } else { + require.NoError(t, os.Unsetenv("SQL_DSN")) + } + }() + + common.IsMasterNode = false + common.SQLitePath = fmt.Sprintf("file:%s_init?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_")) + common.UsingSQLite = false + common.UsingMySQL = false + common.UsingPostgreSQL = false + require.NoError(t, os.Setenv("SQL_DSN", "local")) + + require.NoError(t, model.InitDB()) + if model.DB != nil { + sqlDB, err := model.DB.DB() + if err == nil { + _ = sqlDB.Close() + } + } +} + +func withTieredBillingConfig(t *testing.T, modes map[string]string, exprs map[string]string) { + t.Helper() + + saved := map[string]string{} + require.NoError(t, config.GlobalConfig.SaveToDB(func(key, value string) error { + if strings.HasPrefix(key, "billing_setting.") { + saved[key] = value + } + return nil + })) + t.Cleanup(func() { + require.NoError(t, config.GlobalConfig.LoadFromDB(saved)) + model.InvalidatePricingCache() + }) + + modeBytes, err := common.Marshal(modes) + require.NoError(t, err) + exprBytes, err := common.Marshal(exprs) + require.NoError(t, err) + + require.NoError(t, config.GlobalConfig.LoadFromDB(map[string]string{ + "billing_setting.billing_mode": string(modeBytes), + "billing_setting.billing_expr": string(exprBytes), + })) + model.InvalidatePricingCache() +} + +func withSelfUseModeDisabled(t *testing.T) { + t.Helper() + + original := operation_setting.SelfUseModeEnabled + operation_setting.SelfUseModeEnabled = false + t.Cleanup(func() { + operation_setting.SelfUseModeEnabled = original + }) +} + +func decodeListModelsResponse(t *testing.T, recorder *httptest.ResponseRecorder) map[string]struct{} { + t.Helper() + + require.Equal(t, http.StatusOK, recorder.Code) + var payload listModelsResponse + require.NoError(t, common.Unmarshal(recorder.Body.Bytes(), &payload)) + require.True(t, payload.Success) + require.Equal(t, "list", payload.Object) + + ids := make(map[string]struct{}, len(payload.Data)) + for _, item := range payload.Data { + ids[item.Id] = struct{}{} + } + return ids +} + +func pricingByModelName(pricings []model.Pricing) map[string]model.Pricing { + byName := make(map[string]model.Pricing, len(pricings)) + for _, pricing := range pricings { + byName[pricing.ModelName] = pricing + } + return byName +} + +func TestListModelsIncludesTieredBillingModel(t *testing.T) { + withSelfUseModeDisabled(t) + withTieredBillingConfig(t, map[string]string{ + "zz-tiered-visible-model": "tiered_expr", + "zz-tiered-empty-expr-model": "tiered_expr", + "zz-tiered-missing-expr-model": "tiered_expr", + }, map[string]string{ + "zz-tiered-visible-model": `tier("base", p * 1 + c * 2)`, + "zz-tiered-empty-expr-model": " ", + }) + + db := setupModelListControllerTestDB(t) + require.NoError(t, db.Create(&model.User{ + Id: 1001, + Username: "model-list-user", + Password: "password", + Group: "default", + Status: common.UserStatusEnabled, + }).Error) + require.NoError(t, db.Create(&[]model.Ability{ + {Group: "default", Model: "zz-tiered-visible-model", ChannelId: 1, Enabled: true}, + {Group: "default", Model: "zz-tiered-empty-expr-model", ChannelId: 1, Enabled: true}, + {Group: "default", Model: "zz-tiered-missing-expr-model", ChannelId: 1, Enabled: true}, + {Group: "default", Model: "zz-unpriced-model", ChannelId: 1, Enabled: true}, + }).Error) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) + ctx.Set("id", 1001) + + ListModels(ctx, constant.ChannelTypeOpenAI) + + ids := decodeListModelsResponse(t, recorder) + require.Contains(t, ids, "zz-tiered-visible-model") + require.NotContains(t, ids, "zz-tiered-empty-expr-model") + require.NotContains(t, ids, "zz-tiered-missing-expr-model") + require.NotContains(t, ids, "zz-unpriced-model") + + pricingByName := pricingByModelName(model.GetPricing()) + visiblePricing, ok := pricingByName["zz-tiered-visible-model"] + require.True(t, ok) + require.Equal(t, "tiered_expr", visiblePricing.BillingMode) + require.NotEmpty(t, visiblePricing.BillingExpr) + + emptyExprPricing, ok := pricingByName["zz-tiered-empty-expr-model"] + require.True(t, ok) + require.Empty(t, emptyExprPricing.BillingMode) + require.Empty(t, emptyExprPricing.BillingExpr) + + missingExprPricing, ok := pricingByName["zz-tiered-missing-expr-model"] + require.True(t, ok) + require.Empty(t, missingExprPricing.BillingMode) + require.Empty(t, missingExprPricing.BillingExpr) +} + +func TestListModelsTokenLimitIncludesTieredBillingModel(t *testing.T) { + withSelfUseModeDisabled(t) + withTieredBillingConfig(t, map[string]string{ + "zz-token-tiered-visible-model": "tiered_expr", + "zz-token-tiered-empty-expr-model": "tiered_expr", + "zz-token-tiered-missing-expr-model": "tiered_expr", + }, map[string]string{ + "zz-token-tiered-visible-model": `tier("base", p * 1 + c * 2)`, + "zz-token-tiered-empty-expr-model": "", + }) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) + common.SetContextKey(ctx, constant.ContextKeyTokenModelLimitEnabled, true) + common.SetContextKey(ctx, constant.ContextKeyTokenModelLimit, map[string]bool{ + "zz-token-tiered-visible-model": true, + "zz-token-tiered-empty-expr-model": true, + "zz-token-tiered-missing-expr-model": true, + "zz-token-unpriced-model": true, + }) + + ListModels(ctx, constant.ChannelTypeOpenAI) + + ids := decodeListModelsResponse(t, recorder) + require.Contains(t, ids, "zz-token-tiered-visible-model") + require.NotContains(t, ids, "zz-token-tiered-empty-expr-model") + require.NotContains(t, ids, "zz-token-tiered-missing-expr-model") + require.NotContains(t, ids, "zz-token-unpriced-model") +} diff --git a/model/option.go b/model/option.go index ae4e5ca3..871f73a2 100644 --- a/model/option.go +++ b/model/option.go @@ -578,6 +578,9 @@ func handleConfigUpdate(key, value string) bool { performance_setting.UpdateAndSync() } else if configName == "tool_price_setting" { operation_setting.RebuildToolPriceIndex() + } else if configName == "billing_setting" { + InvalidatePricingCache() + ratio_setting.InvalidateExposedDataCache() } return true // 已处理 diff --git a/model/pricing.go b/model/pricing.go index 0fe23562..fe927585 100644 --- a/model/pricing.go +++ b/model/pricing.go @@ -77,6 +77,29 @@ func GetPricing() []Pricing { return pricingMap } +func InvalidatePricingCache() { + updatePricingLock.Lock() + defer updatePricingLock.Unlock() + + pricingMap = nil + vendorsList = nil + lastGetPricingTime = time.Time{} +} + +func HasModelBillingConfig(modelName string) bool { + if _, ok := ratio_setting.GetModelPrice(modelName, false); ok { + return true + } + if _, ok, _ := ratio_setting.GetModelRatio(modelName); ok { + return true + } + if billing_setting.GetBillingMode(modelName) != billing_setting.BillingModeTieredExpr { + return false + } + expr, ok := billing_setting.GetBillingExpr(modelName) + return ok && strings.TrimSpace(expr) != "" +} + // GetVendors 返回当前定价接口使用到的供应商信息 func GetVendors() []PricingVendor { if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 { @@ -323,7 +346,7 @@ func updatePricing() { pricing.AudioCompletionRatio = &audioCompletionRatio } if billingMode := billing_setting.GetBillingMode(model); billingMode == "tiered_expr" { - if expr, ok := billing_setting.GetBillingExpr(model); ok && expr != "" { + if expr, ok := billing_setting.GetBillingExpr(model); ok && strings.TrimSpace(expr) != "" { pricing.BillingMode = billingMode pricing.BillingExpr = expr } diff --git a/relay/helper/price.go b/relay/helper/price.go index 52b971c2..b4f8e662 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -224,19 +224,7 @@ func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) (types } func ContainPriceOrRatio(modelName string) bool { - _, ok := ratio_setting.GetModelPrice(modelName, false) - if ok { - return true - } - _, ok, _ = ratio_setting.GetModelRatio(modelName) - if ok { - return true - } - if billing_setting.GetBillingMode(modelName) == billing_setting.BillingModeTieredExpr { - _, ok = billing_setting.GetBillingExpr(modelName) - return ok - } - return false + return model.HasModelBillingConfig(modelName) } func modelPriceHelperTiered(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, meta *types.TokenCountMeta, groupRatioInfo types.GroupRatioInfo) (types.PriceData, error) {