e3d64cb76d
fix: include tiered billing models in model listing
243 lines
7.5 KiB
Go
243 lines
7.5 KiB
Go
package controller
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/QuantumNous/new-api/common"
|
|
"github.com/QuantumNous/new-api/constant"
|
|
"github.com/QuantumNous/new-api/dto"
|
|
"github.com/QuantumNous/new-api/model"
|
|
"github.com/QuantumNous/new-api/setting/config"
|
|
"github.com/QuantumNous/new-api/setting/operation_setting"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/glebarez/sqlite"
|
|
"github.com/stretchr/testify/require"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
type listModelsResponse struct {
|
|
Success bool `json:"success"`
|
|
Data []dto.OpenAIModels `json:"data"`
|
|
Object string `json:"object"`
|
|
}
|
|
|
|
func setupModelListControllerTestDB(t *testing.T) *gorm.DB {
|
|
t.Helper()
|
|
|
|
initModelListColumnNames(t)
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
common.UsingSQLite = true
|
|
common.UsingMySQL = false
|
|
common.UsingPostgreSQL = false
|
|
common.RedisEnabled = false
|
|
|
|
dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_"))
|
|
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
|
require.NoError(t, err)
|
|
model.DB = db
|
|
model.LOG_DB = db
|
|
|
|
require.NoError(t, db.AutoMigrate(&model.User{}, &model.Channel{}, &model.Ability{}, &model.Model{}, &model.Vendor{}))
|
|
|
|
t.Cleanup(func() {
|
|
sqlDB, err := db.DB()
|
|
if err == nil {
|
|
_ = sqlDB.Close()
|
|
}
|
|
})
|
|
|
|
return db
|
|
}
|
|
|
|
func initModelListColumnNames(t *testing.T) {
|
|
t.Helper()
|
|
|
|
originalIsMasterNode := common.IsMasterNode
|
|
originalSQLitePath := common.SQLitePath
|
|
originalUsingSQLite := common.UsingSQLite
|
|
originalUsingMySQL := common.UsingMySQL
|
|
originalUsingPostgreSQL := common.UsingPostgreSQL
|
|
originalSQLDSN, hadSQLDSN := os.LookupEnv("SQL_DSN")
|
|
defer func() {
|
|
common.IsMasterNode = originalIsMasterNode
|
|
common.SQLitePath = originalSQLitePath
|
|
common.UsingSQLite = originalUsingSQLite
|
|
common.UsingMySQL = originalUsingMySQL
|
|
common.UsingPostgreSQL = originalUsingPostgreSQL
|
|
if hadSQLDSN {
|
|
require.NoError(t, os.Setenv("SQL_DSN", originalSQLDSN))
|
|
} else {
|
|
require.NoError(t, os.Unsetenv("SQL_DSN"))
|
|
}
|
|
}()
|
|
|
|
common.IsMasterNode = false
|
|
common.SQLitePath = fmt.Sprintf("file:%s_init?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_"))
|
|
common.UsingSQLite = false
|
|
common.UsingMySQL = false
|
|
common.UsingPostgreSQL = false
|
|
require.NoError(t, os.Setenv("SQL_DSN", "local"))
|
|
|
|
require.NoError(t, model.InitDB())
|
|
if model.DB != nil {
|
|
sqlDB, err := model.DB.DB()
|
|
if err == nil {
|
|
_ = sqlDB.Close()
|
|
}
|
|
}
|
|
}
|
|
|
|
func withTieredBillingConfig(t *testing.T, modes map[string]string, exprs map[string]string) {
|
|
t.Helper()
|
|
|
|
saved := map[string]string{}
|
|
require.NoError(t, config.GlobalConfig.SaveToDB(func(key, value string) error {
|
|
if strings.HasPrefix(key, "billing_setting.") {
|
|
saved[key] = value
|
|
}
|
|
return nil
|
|
}))
|
|
t.Cleanup(func() {
|
|
require.NoError(t, config.GlobalConfig.LoadFromDB(saved))
|
|
model.InvalidatePricingCache()
|
|
})
|
|
|
|
modeBytes, err := common.Marshal(modes)
|
|
require.NoError(t, err)
|
|
exprBytes, err := common.Marshal(exprs)
|
|
require.NoError(t, err)
|
|
|
|
require.NoError(t, config.GlobalConfig.LoadFromDB(map[string]string{
|
|
"billing_setting.billing_mode": string(modeBytes),
|
|
"billing_setting.billing_expr": string(exprBytes),
|
|
}))
|
|
model.InvalidatePricingCache()
|
|
}
|
|
|
|
func withSelfUseModeDisabled(t *testing.T) {
|
|
t.Helper()
|
|
|
|
original := operation_setting.SelfUseModeEnabled
|
|
operation_setting.SelfUseModeEnabled = false
|
|
t.Cleanup(func() {
|
|
operation_setting.SelfUseModeEnabled = original
|
|
})
|
|
}
|
|
|
|
func decodeListModelsResponse(t *testing.T, recorder *httptest.ResponseRecorder) map[string]struct{} {
|
|
t.Helper()
|
|
|
|
require.Equal(t, http.StatusOK, recorder.Code)
|
|
var payload listModelsResponse
|
|
require.NoError(t, common.Unmarshal(recorder.Body.Bytes(), &payload))
|
|
require.True(t, payload.Success)
|
|
require.Equal(t, "list", payload.Object)
|
|
|
|
ids := make(map[string]struct{}, len(payload.Data))
|
|
for _, item := range payload.Data {
|
|
ids[item.Id] = struct{}{}
|
|
}
|
|
return ids
|
|
}
|
|
|
|
func pricingByModelName(pricings []model.Pricing) map[string]model.Pricing {
|
|
byName := make(map[string]model.Pricing, len(pricings))
|
|
for _, pricing := range pricings {
|
|
byName[pricing.ModelName] = pricing
|
|
}
|
|
return byName
|
|
}
|
|
|
|
func TestListModelsIncludesTieredBillingModel(t *testing.T) {
|
|
withSelfUseModeDisabled(t)
|
|
withTieredBillingConfig(t, map[string]string{
|
|
"zz-tiered-visible-model": "tiered_expr",
|
|
"zz-tiered-empty-expr-model": "tiered_expr",
|
|
"zz-tiered-missing-expr-model": "tiered_expr",
|
|
}, map[string]string{
|
|
"zz-tiered-visible-model": `tier("base", p * 1 + c * 2)`,
|
|
"zz-tiered-empty-expr-model": " ",
|
|
})
|
|
|
|
db := setupModelListControllerTestDB(t)
|
|
require.NoError(t, db.Create(&model.User{
|
|
Id: 1001,
|
|
Username: "model-list-user",
|
|
Password: "password",
|
|
Group: "default",
|
|
Status: common.UserStatusEnabled,
|
|
}).Error)
|
|
require.NoError(t, db.Create(&[]model.Ability{
|
|
{Group: "default", Model: "zz-tiered-visible-model", ChannelId: 1, Enabled: true},
|
|
{Group: "default", Model: "zz-tiered-empty-expr-model", ChannelId: 1, Enabled: true},
|
|
{Group: "default", Model: "zz-tiered-missing-expr-model", ChannelId: 1, Enabled: true},
|
|
{Group: "default", Model: "zz-unpriced-model", ChannelId: 1, Enabled: true},
|
|
}).Error)
|
|
|
|
recorder := httptest.NewRecorder()
|
|
ctx, _ := gin.CreateTestContext(recorder)
|
|
ctx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
|
|
ctx.Set("id", 1001)
|
|
|
|
ListModels(ctx, constant.ChannelTypeOpenAI)
|
|
|
|
ids := decodeListModelsResponse(t, recorder)
|
|
require.Contains(t, ids, "zz-tiered-visible-model")
|
|
require.NotContains(t, ids, "zz-tiered-empty-expr-model")
|
|
require.NotContains(t, ids, "zz-tiered-missing-expr-model")
|
|
require.NotContains(t, ids, "zz-unpriced-model")
|
|
|
|
pricingByName := pricingByModelName(model.GetPricing())
|
|
visiblePricing, ok := pricingByName["zz-tiered-visible-model"]
|
|
require.True(t, ok)
|
|
require.Equal(t, "tiered_expr", visiblePricing.BillingMode)
|
|
require.NotEmpty(t, visiblePricing.BillingExpr)
|
|
|
|
emptyExprPricing, ok := pricingByName["zz-tiered-empty-expr-model"]
|
|
require.True(t, ok)
|
|
require.Empty(t, emptyExprPricing.BillingMode)
|
|
require.Empty(t, emptyExprPricing.BillingExpr)
|
|
|
|
missingExprPricing, ok := pricingByName["zz-tiered-missing-expr-model"]
|
|
require.True(t, ok)
|
|
require.Empty(t, missingExprPricing.BillingMode)
|
|
require.Empty(t, missingExprPricing.BillingExpr)
|
|
}
|
|
|
|
func TestListModelsTokenLimitIncludesTieredBillingModel(t *testing.T) {
|
|
withSelfUseModeDisabled(t)
|
|
withTieredBillingConfig(t, map[string]string{
|
|
"zz-token-tiered-visible-model": "tiered_expr",
|
|
"zz-token-tiered-empty-expr-model": "tiered_expr",
|
|
"zz-token-tiered-missing-expr-model": "tiered_expr",
|
|
}, map[string]string{
|
|
"zz-token-tiered-visible-model": `tier("base", p * 1 + c * 2)`,
|
|
"zz-token-tiered-empty-expr-model": "",
|
|
})
|
|
|
|
recorder := httptest.NewRecorder()
|
|
ctx, _ := gin.CreateTestContext(recorder)
|
|
ctx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
|
|
common.SetContextKey(ctx, constant.ContextKeyTokenModelLimitEnabled, true)
|
|
common.SetContextKey(ctx, constant.ContextKeyTokenModelLimit, map[string]bool{
|
|
"zz-token-tiered-visible-model": true,
|
|
"zz-token-tiered-empty-expr-model": true,
|
|
"zz-token-tiered-missing-expr-model": true,
|
|
"zz-token-unpriced-model": true,
|
|
})
|
|
|
|
ListModels(ctx, constant.ChannelTypeOpenAI)
|
|
|
|
ids := decodeListModelsResponse(t, recorder)
|
|
require.Contains(t, ids, "zz-token-tiered-visible-model")
|
|
require.NotContains(t, ids, "zz-token-tiered-empty-expr-model")
|
|
require.NotContains(t, ids, "zz-token-tiered-missing-expr-model")
|
|
require.NotContains(t, ids, "zz-token-unpriced-model")
|
|
}
|