From 0220df84291c127f49df80c7e1cfcbf4f3ffde49 Mon Sep 17 00:00:00 2001 From: yyhhyyyyyy Date: Thu, 9 Apr 2026 17:08:52 +0800 Subject: [PATCH] fix(channel-test): support tiered billing model tests (#4145) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pre-fill BillingRequestInput from dto.Request before ModelPriceHelper, so tiered_expr billing resolves param() from the structured request instead of reading HTTP body (which is empty in channel-test context). - attachTestBillingRequestInput: marshal dto.Request → RequestInput - ResolveIncomingBillingExprRequestInput: early-return when pre-filled - settleTestQuota / buildTestLogOther: align test settlement & logging with production TryTieredSettle / InjectTieredBillingInfo paths --- controller/channel-test.go | 68 ++++++++++++++++++---- controller/channel_test_internal_test.go | 71 +++++++++++++++++++++++ relay/helper/billing_expr_request.go | 35 +++++++++++ relay/helper/billing_expr_request_test.go | 28 +++++++++ relay/helper/price_test.go | 62 ++++++++++++++++++++ 5 files changed, 252 insertions(+), 12 deletions(-) create mode 100644 controller/channel_test_internal_test.go create mode 100644 relay/helper/price_test.go diff --git a/controller/channel-test.go b/controller/channel-test.go index bdd67d27..857a34d6 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -20,6 +20,7 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/middleware" "github.com/QuantumNous/new-api/model" + "github.com/QuantumNous/new-api/pkg/billingexpr" "github.com/QuantumNous/new-api/relay" relaycommon "github.com/QuantumNous/new-api/relay/common" relayconstant "github.com/QuantumNous/new-api/relay/constant" @@ -232,6 +233,15 @@ func testChannel(channel *model.Channel, testModel string, endpointType string, info.IsChannelTest = true info.InitChannelMeta(c) + err = attachTestBillingRequestInput(info, request) + if err != nil { + return testResult{ + context: c, + localErr: err, + newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed), + } + } + err = helper.ModelMappedHelper(c, info, request) if err != nil { return testResult{ @@ -468,21 +478,11 @@ func testChannel(channel *model.Channel, testModel string, endpointType string, } info.SetEstimatePromptTokens(usage.PromptTokens) - quota := 0 - if !priceData.UsePrice { - quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio)) - quota = int(math.Round(float64(quota) * priceData.ModelRatio)) - if priceData.ModelRatio != 0 && quota <= 0 { - quota = 1 - } - } else { - quota = int(priceData.ModelPrice * common.QuotaPerUnit) - } + quota, tieredResult := settleTestQuota(info, priceData, usage) tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() consumedTime := float64(milliseconds) / 1000.0 - other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio, - usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) + other := buildTestLogOther(c, info, priceData, usage, tieredResult) model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{ ChannelId: channel.Id, PromptTokens: usage.PromptTokens, @@ -504,6 +504,50 @@ func testChannel(channel *model.Channel, testModel string, endpointType string, } } +func attachTestBillingRequestInput(info *relaycommon.RelayInfo, request dto.Request) error { + if info == nil { + return nil + } + + input, err := helper.BuildBillingExprRequestInputFromRequest(request, info.RequestHeaders) + if err != nil { + return err + } + info.BillingRequestInput = &input + return nil +} + +func settleTestQuota(info *relaycommon.RelayInfo, priceData types.PriceData, usage *dto.Usage) (int, *billingexpr.TieredResult) { + if usage != nil && info != nil && info.TieredBillingSnapshot != nil { + isClaudeUsageSemantic := usage.UsageSemantic == "anthropic" || info.GetFinalRequestRelayFormat() == types.RelayFormatClaude + usedVars := billingexpr.UsedVars(info.TieredBillingSnapshot.ExprString) + if ok, quota, result := service.TryTieredSettle(info, service.BuildTieredTokenParams(usage, isClaudeUsageSemantic, usedVars)); ok { + return quota, result + } + } + + quota := 0 + if !priceData.UsePrice { + quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio)) + quota = int(math.Round(float64(quota) * priceData.ModelRatio)) + if priceData.ModelRatio != 0 && quota <= 0 { + quota = 1 + } + return quota, nil + } + + return int(priceData.ModelPrice * common.QuotaPerUnit), nil +} + +func buildTestLogOther(c *gin.Context, info *relaycommon.RelayInfo, priceData types.PriceData, usage *dto.Usage, tieredResult *billingexpr.TieredResult) map[string]interface{} { + other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio, + usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) + if tieredResult != nil { + service.InjectTieredBillingInfo(other, info, tieredResult) + } + return other +} + func coerceTestUsage(usageAny any, isStream bool, estimatePromptTokens int) (*dto.Usage, error) { switch u := usageAny.(type) { case *dto.Usage: diff --git a/controller/channel_test_internal_test.go b/controller/channel_test_internal_test.go new file mode 100644 index 00000000..9c26d623 --- /dev/null +++ b/controller/channel_test_internal_test.go @@ -0,0 +1,71 @@ +package controller + +import ( + "net/http/httptest" + "testing" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/pkg/billingexpr" + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/types" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestSettleTestQuotaUsesTieredBilling(t *testing.T) { + info := &relaycommon.RelayInfo{ + TieredBillingSnapshot: &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: `param("stream") == true ? tier("stream", p * 3) : tier("base", p * 2)`, + ExprHash: billingexpr.ExprHashString(`param("stream") == true ? tier("stream", p * 3) : tier("base", p * 2)`), + GroupRatio: 1, + EstimatedTier: "stream", + QuotaPerUnit: common.QuotaPerUnit, + ExprVersion: 1, + }, + BillingRequestInput: &billingexpr.RequestInput{ + Body: []byte(`{"stream":true}`), + }, + } + + quota, result := settleTestQuota(info, types.PriceData{ + ModelRatio: 1, + CompletionRatio: 2, + }, &dto.Usage{ + PromptTokens: 1000, + }) + + require.Equal(t, 1500, quota) + require.NotNil(t, result) + require.Equal(t, "stream", result.MatchedTier) +} + +func TestBuildTestLogOtherInjectsTieredInfo(t *testing.T) { + gin.SetMode(gin.TestMode) + ctx, _ := gin.CreateTestContext(httptest.NewRecorder()) + + info := &relaycommon.RelayInfo{ + TieredBillingSnapshot: &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: `tier("base", p * 2)`, + }, + ChannelMeta: &relaycommon.ChannelMeta{}, + } + priceData := types.PriceData{ + GroupRatioInfo: types.GroupRatioInfo{GroupRatio: 1}, + } + usage := &dto.Usage{ + PromptTokensDetails: dto.InputTokenDetails{ + CachedTokens: 12, + }, + } + + other := buildTestLogOther(ctx, info, priceData, usage, &billingexpr.TieredResult{ + MatchedTier: "base", + }) + + require.Equal(t, "tiered_expr", other["billing_mode"]) + require.Equal(t, "base", other["matched_tier"]) + require.NotEmpty(t, other["expr_b64"]) +} diff --git a/relay/helper/billing_expr_request.go b/relay/helper/billing_expr_request.go index 404a348f..636dee52 100644 --- a/relay/helper/billing_expr_request.go +++ b/relay/helper/billing_expr_request.go @@ -4,12 +4,21 @@ import ( "strings" "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/pkg/billingexpr" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/gin-gonic/gin" ) func ResolveIncomingBillingExprRequestInput(c *gin.Context, info *relaycommon.RelayInfo) (billingexpr.RequestInput, error) { + if info != nil && info.BillingRequestInput != nil { + input := cloneRequestInput(*info.BillingRequestInput) + if len(input.Headers) == 0 { + input.Headers = cloneStringMap(info.RequestHeaders) + } + return input, nil + } + input := billingexpr.RequestInput{} if info != nil { input.Headers = cloneStringMap(info.RequestHeaders) @@ -23,6 +32,22 @@ func ResolveIncomingBillingExprRequestInput(c *gin.Context, info *relaycommon.Re return input, nil } +func BuildBillingExprRequestInputFromRequest(request dto.Request, headers map[string]string) (billingexpr.RequestInput, error) { + input := billingexpr.RequestInput{ + Headers: cloneStringMap(headers), + } + if request == nil { + return input, nil + } + + bodyBytes, err := common.Marshal(request) + if err != nil { + return billingexpr.RequestInput{}, err + } + input.Body = bodyBytes + return input, nil +} + func readIncomingBillingExprBody(c *gin.Context) ([]byte, error) { if c == nil || c.Request == nil || !isJSONContentType(c.Request.Header.Get("Content-Type")) { return nil, nil @@ -34,6 +59,16 @@ func readIncomingBillingExprBody(c *gin.Context) ([]byte, error) { return storage.Bytes() } +func cloneRequestInput(src billingexpr.RequestInput) billingexpr.RequestInput { + input := billingexpr.RequestInput{ + Headers: cloneStringMap(src.Headers), + } + if len(src.Body) > 0 { + input.Body = append([]byte(nil), src.Body...) + } + return input +} + func isJSONContentType(contentType string) bool { contentType = strings.ToLower(strings.TrimSpace(contentType)) return strings.HasPrefix(contentType, "application/json") diff --git a/relay/helper/billing_expr_request_test.go b/relay/helper/billing_expr_request_test.go index c07aaa29..9193f4b4 100644 --- a/relay/helper/billing_expr_request_test.go +++ b/relay/helper/billing_expr_request_test.go @@ -8,9 +8,12 @@ import ( "testing" "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/dto" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/gin-gonic/gin" + "github.com/samber/lo" "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" ) func TestResolveIncomingBillingExprRequestInput(t *testing.T) { @@ -33,3 +36,28 @@ func TestResolveIncomingBillingExprRequestInput(t *testing.T) { require.Equal(t, body, input.Body) require.Equal(t, "application/json", input.Headers["Content-Type"]) } + +func TestBuildBillingExprRequestInputFromRequest(t *testing.T) { + request := &dto.GeneralOpenAIRequest{ + Model: "gemini-3.1-pro-preview", + Stream: lo.ToPtr(true), + Messages: []dto.Message{ + { + Role: "user", + Content: "hi", + }, + }, + MaxTokens: lo.ToPtr(uint(3000)), + } + + input, err := BuildBillingExprRequestInputFromRequest(request, map[string]string{ + "Content-Type": "application/json", + "X-Test": "1", + }) + require.NoError(t, err) + require.Equal(t, "application/json", input.Headers["Content-Type"]) + require.Equal(t, "1", input.Headers["X-Test"]) + require.True(t, gjson.GetBytes(input.Body, "stream").Bool()) + require.Equal(t, "user", gjson.GetBytes(input.Body, "messages.0.role").String()) + require.Equal(t, float64(3000), gjson.GetBytes(input.Body, "max_tokens").Float()) +} diff --git a/relay/helper/price_test.go b/relay/helper/price_test.go new file mode 100644 index 00000000..afa64c4b --- /dev/null +++ b/relay/helper/price_test.go @@ -0,0 +1,62 @@ +package helper + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/pkg/billingexpr" + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/setting/billing_setting" + "github.com/QuantumNous/new-api/setting/config" + "github.com/QuantumNous/new-api/types" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestModelPriceHelperTieredUsesPreloadedRequestInput(t *testing.T) { + gin.SetMode(gin.TestMode) + + saved := map[string]string{} + require.NoError(t, config.GlobalConfig.SaveToDB(func(key, value string) error { + saved[key] = value + return nil + })) + t.Cleanup(func() { + require.NoError(t, config.GlobalConfig.LoadFromDB(saved)) + }) + + require.NoError(t, config.GlobalConfig.LoadFromDB(map[string]string{ + "billing_setting.billing_mode": `{"tiered-test-model":"tiered_expr"}`, + "billing_setting.billing_expr": `{"tiered-test-model":"param(\"stream\") == true ? tier(\"stream\", p * 3) : tier(\"base\", p * 2)"}`, + })) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/channel/test/1", nil) + req.Body = nil + req.ContentLength = 0 + req.Header.Set("Content-Type", "application/json") + ctx.Request = req + ctx.Set("group", "default") + + info := &relaycommon.RelayInfo{ + OriginModelName: "tiered-test-model", + UserGroup: "default", + UsingGroup: "default", + RequestHeaders: map[string]string{"Content-Type": "application/json"}, + BillingRequestInput: &billingexpr.RequestInput{ + Headers: map[string]string{"Content-Type": "application/json"}, + Body: []byte(`{"stream":true}`), + }, + } + + priceData, err := ModelPriceHelper(ctx, info, 1000, &types.TokenCountMeta{}) + require.NoError(t, err) + require.Equal(t, 1500, priceData.QuotaToPreConsume) + require.NotNil(t, info.TieredBillingSnapshot) + require.Equal(t, "stream", info.TieredBillingSnapshot.EstimatedTier) + require.Equal(t, billing_setting.BillingModeTieredExpr, info.TieredBillingSnapshot.BillingMode) + require.Equal(t, common.QuotaPerUnit, info.TieredBillingSnapshot.QuotaPerUnit) +}