Merge pull request #4431 from yyhhyyyyyy/fix/tiered-billing-model-list
fix: include tiered billing models in model listing
This commit is contained in:
+2
-5
@@ -17,7 +17,6 @@ import (
|
|||||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
"github.com/QuantumNous/new-api/service"
|
"github.com/QuantumNous/new-api/service"
|
||||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
|
||||||
"github.com/QuantumNous/new-api/types"
|
"github.com/QuantumNous/new-api/types"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
@@ -134,8 +133,7 @@ func ListModels(c *gin.Context, modelType int) {
|
|||||||
}
|
}
|
||||||
for allowModel, _ := range tokenModelLimit {
|
for allowModel, _ := range tokenModelLimit {
|
||||||
if !acceptUnsetRatioModel {
|
if !acceptUnsetRatioModel {
|
||||||
_, _, exist := ratio_setting.GetModelRatioOrPrice(allowModel)
|
if !model.HasModelBillingConfig(allowModel) {
|
||||||
if !exist {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -182,8 +180,7 @@ func ListModels(c *gin.Context, modelType int) {
|
|||||||
}
|
}
|
||||||
for _, modelName := range models {
|
for _, modelName := range models {
|
||||||
if !acceptUnsetRatioModel {
|
if !acceptUnsetRatioModel {
|
||||||
_, _, exist := ratio_setting.GetModelRatioOrPrice(modelName)
|
if !model.HasModelBillingConfig(modelName) {
|
||||||
if !exist {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,242 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/QuantumNous/new-api/common"
|
||||||
|
"github.com/QuantumNous/new-api/constant"
|
||||||
|
"github.com/QuantumNous/new-api/dto"
|
||||||
|
"github.com/QuantumNous/new-api/model"
|
||||||
|
"github.com/QuantumNous/new-api/setting/config"
|
||||||
|
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/glebarez/sqlite"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type listModelsResponse struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Data []dto.OpenAIModels `json:"data"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupModelListControllerTestDB(t *testing.T) *gorm.DB {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
initModelListColumnNames(t)
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
common.UsingSQLite = true
|
||||||
|
common.UsingMySQL = false
|
||||||
|
common.UsingPostgreSQL = false
|
||||||
|
common.RedisEnabled = false
|
||||||
|
|
||||||
|
dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_"))
|
||||||
|
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
model.DB = db
|
||||||
|
model.LOG_DB = db
|
||||||
|
|
||||||
|
require.NoError(t, db.AutoMigrate(&model.User{}, &model.Channel{}, &model.Ability{}, &model.Model{}, &model.Vendor{}))
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
sqlDB, err := db.DB()
|
||||||
|
if err == nil {
|
||||||
|
_ = sqlDB.Close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func initModelListColumnNames(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
originalIsMasterNode := common.IsMasterNode
|
||||||
|
originalSQLitePath := common.SQLitePath
|
||||||
|
originalUsingSQLite := common.UsingSQLite
|
||||||
|
originalUsingMySQL := common.UsingMySQL
|
||||||
|
originalUsingPostgreSQL := common.UsingPostgreSQL
|
||||||
|
originalSQLDSN, hadSQLDSN := os.LookupEnv("SQL_DSN")
|
||||||
|
defer func() {
|
||||||
|
common.IsMasterNode = originalIsMasterNode
|
||||||
|
common.SQLitePath = originalSQLitePath
|
||||||
|
common.UsingSQLite = originalUsingSQLite
|
||||||
|
common.UsingMySQL = originalUsingMySQL
|
||||||
|
common.UsingPostgreSQL = originalUsingPostgreSQL
|
||||||
|
if hadSQLDSN {
|
||||||
|
require.NoError(t, os.Setenv("SQL_DSN", originalSQLDSN))
|
||||||
|
} else {
|
||||||
|
require.NoError(t, os.Unsetenv("SQL_DSN"))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
common.IsMasterNode = false
|
||||||
|
common.SQLitePath = fmt.Sprintf("file:%s_init?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_"))
|
||||||
|
common.UsingSQLite = false
|
||||||
|
common.UsingMySQL = false
|
||||||
|
common.UsingPostgreSQL = false
|
||||||
|
require.NoError(t, os.Setenv("SQL_DSN", "local"))
|
||||||
|
|
||||||
|
require.NoError(t, model.InitDB())
|
||||||
|
if model.DB != nil {
|
||||||
|
sqlDB, err := model.DB.DB()
|
||||||
|
if err == nil {
|
||||||
|
_ = sqlDB.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func withTieredBillingConfig(t *testing.T, modes map[string]string, exprs map[string]string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
saved := map[string]string{}
|
||||||
|
require.NoError(t, config.GlobalConfig.SaveToDB(func(key, value string) error {
|
||||||
|
if strings.HasPrefix(key, "billing_setting.") {
|
||||||
|
saved[key] = value
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, config.GlobalConfig.LoadFromDB(saved))
|
||||||
|
model.InvalidatePricingCache()
|
||||||
|
})
|
||||||
|
|
||||||
|
modeBytes, err := common.Marshal(modes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
exprBytes, err := common.Marshal(exprs)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, config.GlobalConfig.LoadFromDB(map[string]string{
|
||||||
|
"billing_setting.billing_mode": string(modeBytes),
|
||||||
|
"billing_setting.billing_expr": string(exprBytes),
|
||||||
|
}))
|
||||||
|
model.InvalidatePricingCache()
|
||||||
|
}
|
||||||
|
|
||||||
|
func withSelfUseModeDisabled(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
original := operation_setting.SelfUseModeEnabled
|
||||||
|
operation_setting.SelfUseModeEnabled = false
|
||||||
|
t.Cleanup(func() {
|
||||||
|
operation_setting.SelfUseModeEnabled = original
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeListModelsResponse(t *testing.T, recorder *httptest.ResponseRecorder) map[string]struct{} {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
var payload listModelsResponse
|
||||||
|
require.NoError(t, common.Unmarshal(recorder.Body.Bytes(), &payload))
|
||||||
|
require.True(t, payload.Success)
|
||||||
|
require.Equal(t, "list", payload.Object)
|
||||||
|
|
||||||
|
ids := make(map[string]struct{}, len(payload.Data))
|
||||||
|
for _, item := range payload.Data {
|
||||||
|
ids[item.Id] = struct{}{}
|
||||||
|
}
|
||||||
|
return ids
|
||||||
|
}
|
||||||
|
|
||||||
|
func pricingByModelName(pricings []model.Pricing) map[string]model.Pricing {
|
||||||
|
byName := make(map[string]model.Pricing, len(pricings))
|
||||||
|
for _, pricing := range pricings {
|
||||||
|
byName[pricing.ModelName] = pricing
|
||||||
|
}
|
||||||
|
return byName
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListModelsIncludesTieredBillingModel(t *testing.T) {
|
||||||
|
withSelfUseModeDisabled(t)
|
||||||
|
withTieredBillingConfig(t, map[string]string{
|
||||||
|
"zz-tiered-visible-model": "tiered_expr",
|
||||||
|
"zz-tiered-empty-expr-model": "tiered_expr",
|
||||||
|
"zz-tiered-missing-expr-model": "tiered_expr",
|
||||||
|
}, map[string]string{
|
||||||
|
"zz-tiered-visible-model": `tier("base", p * 1 + c * 2)`,
|
||||||
|
"zz-tiered-empty-expr-model": " ",
|
||||||
|
})
|
||||||
|
|
||||||
|
db := setupModelListControllerTestDB(t)
|
||||||
|
require.NoError(t, db.Create(&model.User{
|
||||||
|
Id: 1001,
|
||||||
|
Username: "model-list-user",
|
||||||
|
Password: "password",
|
||||||
|
Group: "default",
|
||||||
|
Status: common.UserStatusEnabled,
|
||||||
|
}).Error)
|
||||||
|
require.NoError(t, db.Create(&[]model.Ability{
|
||||||
|
{Group: "default", Model: "zz-tiered-visible-model", ChannelId: 1, Enabled: true},
|
||||||
|
{Group: "default", Model: "zz-tiered-empty-expr-model", ChannelId: 1, Enabled: true},
|
||||||
|
{Group: "default", Model: "zz-tiered-missing-expr-model", ChannelId: 1, Enabled: true},
|
||||||
|
{Group: "default", Model: "zz-unpriced-model", ChannelId: 1, Enabled: true},
|
||||||
|
}).Error)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(recorder)
|
||||||
|
ctx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
|
||||||
|
ctx.Set("id", 1001)
|
||||||
|
|
||||||
|
ListModels(ctx, constant.ChannelTypeOpenAI)
|
||||||
|
|
||||||
|
ids := decodeListModelsResponse(t, recorder)
|
||||||
|
require.Contains(t, ids, "zz-tiered-visible-model")
|
||||||
|
require.NotContains(t, ids, "zz-tiered-empty-expr-model")
|
||||||
|
require.NotContains(t, ids, "zz-tiered-missing-expr-model")
|
||||||
|
require.NotContains(t, ids, "zz-unpriced-model")
|
||||||
|
|
||||||
|
pricingByName := pricingByModelName(model.GetPricing())
|
||||||
|
visiblePricing, ok := pricingByName["zz-tiered-visible-model"]
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "tiered_expr", visiblePricing.BillingMode)
|
||||||
|
require.NotEmpty(t, visiblePricing.BillingExpr)
|
||||||
|
|
||||||
|
emptyExprPricing, ok := pricingByName["zz-tiered-empty-expr-model"]
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Empty(t, emptyExprPricing.BillingMode)
|
||||||
|
require.Empty(t, emptyExprPricing.BillingExpr)
|
||||||
|
|
||||||
|
missingExprPricing, ok := pricingByName["zz-tiered-missing-expr-model"]
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Empty(t, missingExprPricing.BillingMode)
|
||||||
|
require.Empty(t, missingExprPricing.BillingExpr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListModelsTokenLimitIncludesTieredBillingModel(t *testing.T) {
|
||||||
|
withSelfUseModeDisabled(t)
|
||||||
|
withTieredBillingConfig(t, map[string]string{
|
||||||
|
"zz-token-tiered-visible-model": "tiered_expr",
|
||||||
|
"zz-token-tiered-empty-expr-model": "tiered_expr",
|
||||||
|
"zz-token-tiered-missing-expr-model": "tiered_expr",
|
||||||
|
}, map[string]string{
|
||||||
|
"zz-token-tiered-visible-model": `tier("base", p * 1 + c * 2)`,
|
||||||
|
"zz-token-tiered-empty-expr-model": "",
|
||||||
|
})
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(recorder)
|
||||||
|
ctx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
|
||||||
|
common.SetContextKey(ctx, constant.ContextKeyTokenModelLimitEnabled, true)
|
||||||
|
common.SetContextKey(ctx, constant.ContextKeyTokenModelLimit, map[string]bool{
|
||||||
|
"zz-token-tiered-visible-model": true,
|
||||||
|
"zz-token-tiered-empty-expr-model": true,
|
||||||
|
"zz-token-tiered-missing-expr-model": true,
|
||||||
|
"zz-token-unpriced-model": true,
|
||||||
|
})
|
||||||
|
|
||||||
|
ListModels(ctx, constant.ChannelTypeOpenAI)
|
||||||
|
|
||||||
|
ids := decodeListModelsResponse(t, recorder)
|
||||||
|
require.Contains(t, ids, "zz-token-tiered-visible-model")
|
||||||
|
require.NotContains(t, ids, "zz-token-tiered-empty-expr-model")
|
||||||
|
require.NotContains(t, ids, "zz-token-tiered-missing-expr-model")
|
||||||
|
require.NotContains(t, ids, "zz-token-unpriced-model")
|
||||||
|
}
|
||||||
@@ -578,6 +578,9 @@ func handleConfigUpdate(key, value string) bool {
|
|||||||
performance_setting.UpdateAndSync()
|
performance_setting.UpdateAndSync()
|
||||||
} else if configName == "tool_price_setting" {
|
} else if configName == "tool_price_setting" {
|
||||||
operation_setting.RebuildToolPriceIndex()
|
operation_setting.RebuildToolPriceIndex()
|
||||||
|
} else if configName == "billing_setting" {
|
||||||
|
InvalidatePricingCache()
|
||||||
|
ratio_setting.InvalidateExposedDataCache()
|
||||||
}
|
}
|
||||||
|
|
||||||
return true // 已处理
|
return true // 已处理
|
||||||
|
|||||||
+24
-1
@@ -77,6 +77,29 @@ func GetPricing() []Pricing {
|
|||||||
return pricingMap
|
return pricingMap
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func InvalidatePricingCache() {
|
||||||
|
updatePricingLock.Lock()
|
||||||
|
defer updatePricingLock.Unlock()
|
||||||
|
|
||||||
|
pricingMap = nil
|
||||||
|
vendorsList = nil
|
||||||
|
lastGetPricingTime = time.Time{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func HasModelBillingConfig(modelName string) bool {
|
||||||
|
if _, ok := ratio_setting.GetModelPrice(modelName, false); ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if _, ok, _ := ratio_setting.GetModelRatio(modelName); ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if billing_setting.GetBillingMode(modelName) != billing_setting.BillingModeTieredExpr {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
expr, ok := billing_setting.GetBillingExpr(modelName)
|
||||||
|
return ok && strings.TrimSpace(expr) != ""
|
||||||
|
}
|
||||||
|
|
||||||
// GetVendors 返回当前定价接口使用到的供应商信息
|
// GetVendors 返回当前定价接口使用到的供应商信息
|
||||||
func GetVendors() []PricingVendor {
|
func GetVendors() []PricingVendor {
|
||||||
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
||||||
@@ -323,7 +346,7 @@ func updatePricing() {
|
|||||||
pricing.AudioCompletionRatio = &audioCompletionRatio
|
pricing.AudioCompletionRatio = &audioCompletionRatio
|
||||||
}
|
}
|
||||||
if billingMode := billing_setting.GetBillingMode(model); billingMode == "tiered_expr" {
|
if billingMode := billing_setting.GetBillingMode(model); billingMode == "tiered_expr" {
|
||||||
if expr, ok := billing_setting.GetBillingExpr(model); ok && expr != "" {
|
if expr, ok := billing_setting.GetBillingExpr(model); ok && strings.TrimSpace(expr) != "" {
|
||||||
pricing.BillingMode = billingMode
|
pricing.BillingMode = billingMode
|
||||||
pricing.BillingExpr = expr
|
pricing.BillingExpr = expr
|
||||||
}
|
}
|
||||||
|
|||||||
+1
-13
@@ -224,19 +224,7 @@ func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) (types
|
|||||||
}
|
}
|
||||||
|
|
||||||
func ContainPriceOrRatio(modelName string) bool {
|
func ContainPriceOrRatio(modelName string) bool {
|
||||||
_, ok := ratio_setting.GetModelPrice(modelName, false)
|
return model.HasModelBillingConfig(modelName)
|
||||||
if ok {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
_, ok, _ = ratio_setting.GetModelRatio(modelName)
|
|
||||||
if ok {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if billing_setting.GetBillingMode(modelName) == billing_setting.BillingModeTieredExpr {
|
|
||||||
_, ok = billing_setting.GetBillingExpr(modelName)
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func modelPriceHelperTiered(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, meta *types.TokenCountMeta, groupRatioInfo types.GroupRatioInfo) (types.PriceData, error) {
|
func modelPriceHelperTiered(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, meta *types.TokenCountMeta, groupRatioInfo types.GroupRatioInfo) (types.PriceData, error) {
|
||||||
|
|||||||
Reference in New Issue
Block a user