package controller import ( "fmt" "net/http" "net/http/httptest" "os" "strings" "testing" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/setting/config" "github.com/QuantumNous/new-api/setting/operation_setting" "github.com/gin-gonic/gin" "github.com/glebarez/sqlite" "github.com/stretchr/testify/require" "gorm.io/gorm" ) type listModelsResponse struct { Success bool `json:"success"` Data []dto.OpenAIModels `json:"data"` Object string `json:"object"` } func setupModelListControllerTestDB(t *testing.T) *gorm.DB { t.Helper() initModelListColumnNames(t) gin.SetMode(gin.TestMode) common.UsingSQLite = true common.UsingMySQL = false common.UsingPostgreSQL = false common.RedisEnabled = false dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_")) db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{}) require.NoError(t, err) model.DB = db model.LOG_DB = db require.NoError(t, db.AutoMigrate(&model.User{}, &model.Channel{}, &model.Ability{}, &model.Model{}, &model.Vendor{})) t.Cleanup(func() { sqlDB, err := db.DB() if err == nil { _ = sqlDB.Close() } }) return db } func initModelListColumnNames(t *testing.T) { t.Helper() originalIsMasterNode := common.IsMasterNode originalSQLitePath := common.SQLitePath originalUsingSQLite := common.UsingSQLite originalUsingMySQL := common.UsingMySQL originalUsingPostgreSQL := common.UsingPostgreSQL originalSQLDSN, hadSQLDSN := os.LookupEnv("SQL_DSN") defer func() { common.IsMasterNode = originalIsMasterNode common.SQLitePath = originalSQLitePath common.UsingSQLite = originalUsingSQLite common.UsingMySQL = originalUsingMySQL common.UsingPostgreSQL = originalUsingPostgreSQL if hadSQLDSN { require.NoError(t, os.Setenv("SQL_DSN", originalSQLDSN)) } else { require.NoError(t, os.Unsetenv("SQL_DSN")) } }() common.IsMasterNode = false common.SQLitePath = fmt.Sprintf("file:%s_init?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_")) common.UsingSQLite = false common.UsingMySQL = false common.UsingPostgreSQL = false require.NoError(t, os.Setenv("SQL_DSN", "local")) require.NoError(t, model.InitDB()) if model.DB != nil { sqlDB, err := model.DB.DB() if err == nil { _ = sqlDB.Close() } } } func withTieredBillingConfig(t *testing.T, modes map[string]string, exprs map[string]string) { t.Helper() saved := map[string]string{} require.NoError(t, config.GlobalConfig.SaveToDB(func(key, value string) error { if strings.HasPrefix(key, "billing_setting.") { saved[key] = value } return nil })) t.Cleanup(func() { require.NoError(t, config.GlobalConfig.LoadFromDB(saved)) model.InvalidatePricingCache() }) modeBytes, err := common.Marshal(modes) require.NoError(t, err) exprBytes, err := common.Marshal(exprs) require.NoError(t, err) require.NoError(t, config.GlobalConfig.LoadFromDB(map[string]string{ "billing_setting.billing_mode": string(modeBytes), "billing_setting.billing_expr": string(exprBytes), })) model.InvalidatePricingCache() } func withSelfUseModeDisabled(t *testing.T) { t.Helper() original := operation_setting.SelfUseModeEnabled operation_setting.SelfUseModeEnabled = false t.Cleanup(func() { operation_setting.SelfUseModeEnabled = original }) } func decodeListModelsResponse(t *testing.T, recorder *httptest.ResponseRecorder) map[string]struct{} { t.Helper() require.Equal(t, http.StatusOK, recorder.Code) var payload listModelsResponse require.NoError(t, common.Unmarshal(recorder.Body.Bytes(), &payload)) require.True(t, payload.Success) require.Equal(t, "list", payload.Object) ids := make(map[string]struct{}, len(payload.Data)) for _, item := range payload.Data { ids[item.Id] = struct{}{} } return ids } func pricingByModelName(pricings []model.Pricing) map[string]model.Pricing { byName := make(map[string]model.Pricing, len(pricings)) for _, pricing := range pricings { byName[pricing.ModelName] = pricing } return byName } func TestListModelsIncludesTieredBillingModel(t *testing.T) { withSelfUseModeDisabled(t) withTieredBillingConfig(t, map[string]string{ "zz-tiered-visible-model": "tiered_expr", "zz-tiered-empty-expr-model": "tiered_expr", "zz-tiered-missing-expr-model": "tiered_expr", }, map[string]string{ "zz-tiered-visible-model": `tier("base", p * 1 + c * 2)`, "zz-tiered-empty-expr-model": " ", }) db := setupModelListControllerTestDB(t) require.NoError(t, db.Create(&model.User{ Id: 1001, Username: "model-list-user", Password: "password", Group: "default", Status: common.UserStatusEnabled, }).Error) require.NoError(t, db.Create(&[]model.Ability{ {Group: "default", Model: "zz-tiered-visible-model", ChannelId: 1, Enabled: true}, {Group: "default", Model: "zz-tiered-empty-expr-model", ChannelId: 1, Enabled: true}, {Group: "default", Model: "zz-tiered-missing-expr-model", ChannelId: 1, Enabled: true}, {Group: "default", Model: "zz-unpriced-model", ChannelId: 1, Enabled: true}, }).Error) recorder := httptest.NewRecorder() ctx, _ := gin.CreateTestContext(recorder) ctx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) ctx.Set("id", 1001) ListModels(ctx, constant.ChannelTypeOpenAI) ids := decodeListModelsResponse(t, recorder) require.Contains(t, ids, "zz-tiered-visible-model") require.NotContains(t, ids, "zz-tiered-empty-expr-model") require.NotContains(t, ids, "zz-tiered-missing-expr-model") require.NotContains(t, ids, "zz-unpriced-model") pricingByName := pricingByModelName(model.GetPricing()) visiblePricing, ok := pricingByName["zz-tiered-visible-model"] require.True(t, ok) require.Equal(t, "tiered_expr", visiblePricing.BillingMode) require.NotEmpty(t, visiblePricing.BillingExpr) emptyExprPricing, ok := pricingByName["zz-tiered-empty-expr-model"] require.True(t, ok) require.Empty(t, emptyExprPricing.BillingMode) require.Empty(t, emptyExprPricing.BillingExpr) missingExprPricing, ok := pricingByName["zz-tiered-missing-expr-model"] require.True(t, ok) require.Empty(t, missingExprPricing.BillingMode) require.Empty(t, missingExprPricing.BillingExpr) } func TestListModelsTokenLimitIncludesTieredBillingModel(t *testing.T) { withSelfUseModeDisabled(t) withTieredBillingConfig(t, map[string]string{ "zz-token-tiered-visible-model": "tiered_expr", "zz-token-tiered-empty-expr-model": "tiered_expr", "zz-token-tiered-missing-expr-model": "tiered_expr", }, map[string]string{ "zz-token-tiered-visible-model": `tier("base", p * 1 + c * 2)`, "zz-token-tiered-empty-expr-model": "", }) recorder := httptest.NewRecorder() ctx, _ := gin.CreateTestContext(recorder) ctx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) common.SetContextKey(ctx, constant.ContextKeyTokenModelLimitEnabled, true) common.SetContextKey(ctx, constant.ContextKeyTokenModelLimit, map[string]bool{ "zz-token-tiered-visible-model": true, "zz-token-tiered-empty-expr-model": true, "zz-token-tiered-missing-expr-model": true, "zz-token-unpriced-model": true, }) ListModels(ctx, constant.ChannelTypeOpenAI) ids := decodeListModelsResponse(t, recorder) require.Contains(t, ids, "zz-token-tiered-visible-model") require.NotContains(t, ids, "zz-token-tiered-empty-expr-model") require.NotContains(t, ids, "zz-token-tiered-missing-expr-model") require.NotContains(t, ids, "zz-token-unpriced-model") }