fix(channel-test): support tiered billing model tests (#4145)
Pre-fill BillingRequestInput from dto.Request before ModelPriceHelper, so tiered_expr billing resolves param() from the structured request instead of reading HTTP body (which is empty in channel-test context). - attachTestBillingRequestInput: marshal dto.Request → RequestInput - ResolveIncomingBillingExprRequestInput: early-return when pre-filled - settleTestQuota / buildTestLogOther: align test settlement & logging with production TryTieredSettle / InjectTieredBillingInfo paths
This commit is contained in:
+56
-12
@@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/QuantumNous/new-api/dto"
|
"github.com/QuantumNous/new-api/dto"
|
||||||
"github.com/QuantumNous/new-api/middleware"
|
"github.com/QuantumNous/new-api/middleware"
|
||||||
"github.com/QuantumNous/new-api/model"
|
"github.com/QuantumNous/new-api/model"
|
||||||
|
"github.com/QuantumNous/new-api/pkg/billingexpr"
|
||||||
"github.com/QuantumNous/new-api/relay"
|
"github.com/QuantumNous/new-api/relay"
|
||||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
||||||
@@ -232,6 +233,15 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
|
|||||||
info.IsChannelTest = true
|
info.IsChannelTest = true
|
||||||
info.InitChannelMeta(c)
|
info.InitChannelMeta(c)
|
||||||
|
|
||||||
|
err = attachTestBillingRequestInput(info, request)
|
||||||
|
if err != nil {
|
||||||
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
err = helper.ModelMappedHelper(c, info, request)
|
err = helper.ModelMappedHelper(c, info, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return testResult{
|
return testResult{
|
||||||
@@ -468,21 +478,11 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
|
|||||||
}
|
}
|
||||||
info.SetEstimatePromptTokens(usage.PromptTokens)
|
info.SetEstimatePromptTokens(usage.PromptTokens)
|
||||||
|
|
||||||
quota := 0
|
quota, tieredResult := settleTestQuota(info, priceData, usage)
|
||||||
if !priceData.UsePrice {
|
|
||||||
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
|
|
||||||
quota = int(math.Round(float64(quota) * priceData.ModelRatio))
|
|
||||||
if priceData.ModelRatio != 0 && quota <= 0 {
|
|
||||||
quota = 1
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
quota = int(priceData.ModelPrice * common.QuotaPerUnit)
|
|
||||||
}
|
|
||||||
tok := time.Now()
|
tok := time.Now()
|
||||||
milliseconds := tok.Sub(tik).Milliseconds()
|
milliseconds := tok.Sub(tik).Milliseconds()
|
||||||
consumedTime := float64(milliseconds) / 1000.0
|
consumedTime := float64(milliseconds) / 1000.0
|
||||||
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
|
other := buildTestLogOther(c, info, priceData, usage, tieredResult)
|
||||||
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
|
||||||
model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
|
model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
|
||||||
ChannelId: channel.Id,
|
ChannelId: channel.Id,
|
||||||
PromptTokens: usage.PromptTokens,
|
PromptTokens: usage.PromptTokens,
|
||||||
@@ -504,6 +504,50 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func attachTestBillingRequestInput(info *relaycommon.RelayInfo, request dto.Request) error {
|
||||||
|
if info == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
input, err := helper.BuildBillingExprRequestInputFromRequest(request, info.RequestHeaders)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
info.BillingRequestInput = &input
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func settleTestQuota(info *relaycommon.RelayInfo, priceData types.PriceData, usage *dto.Usage) (int, *billingexpr.TieredResult) {
|
||||||
|
if usage != nil && info != nil && info.TieredBillingSnapshot != nil {
|
||||||
|
isClaudeUsageSemantic := usage.UsageSemantic == "anthropic" || info.GetFinalRequestRelayFormat() == types.RelayFormatClaude
|
||||||
|
usedVars := billingexpr.UsedVars(info.TieredBillingSnapshot.ExprString)
|
||||||
|
if ok, quota, result := service.TryTieredSettle(info, service.BuildTieredTokenParams(usage, isClaudeUsageSemantic, usedVars)); ok {
|
||||||
|
return quota, result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
quota := 0
|
||||||
|
if !priceData.UsePrice {
|
||||||
|
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
|
||||||
|
quota = int(math.Round(float64(quota) * priceData.ModelRatio))
|
||||||
|
if priceData.ModelRatio != 0 && quota <= 0 {
|
||||||
|
quota = 1
|
||||||
|
}
|
||||||
|
return quota, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return int(priceData.ModelPrice * common.QuotaPerUnit), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildTestLogOther(c *gin.Context, info *relaycommon.RelayInfo, priceData types.PriceData, usage *dto.Usage, tieredResult *billingexpr.TieredResult) map[string]interface{} {
|
||||||
|
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
|
||||||
|
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||||
|
if tieredResult != nil {
|
||||||
|
service.InjectTieredBillingInfo(other, info, tieredResult)
|
||||||
|
}
|
||||||
|
return other
|
||||||
|
}
|
||||||
|
|
||||||
func coerceTestUsage(usageAny any, isStream bool, estimatePromptTokens int) (*dto.Usage, error) {
|
func coerceTestUsage(usageAny any, isStream bool, estimatePromptTokens int) (*dto.Usage, error) {
|
||||||
switch u := usageAny.(type) {
|
switch u := usageAny.(type) {
|
||||||
case *dto.Usage:
|
case *dto.Usage:
|
||||||
|
|||||||
@@ -0,0 +1,71 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/QuantumNous/new-api/common"
|
||||||
|
"github.com/QuantumNous/new-api/dto"
|
||||||
|
"github.com/QuantumNous/new-api/pkg/billingexpr"
|
||||||
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
|
"github.com/QuantumNous/new-api/types"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSettleTestQuotaUsesTieredBilling(t *testing.T) {
|
||||||
|
info := &relaycommon.RelayInfo{
|
||||||
|
TieredBillingSnapshot: &billingexpr.BillingSnapshot{
|
||||||
|
BillingMode: "tiered_expr",
|
||||||
|
ExprString: `param("stream") == true ? tier("stream", p * 3) : tier("base", p * 2)`,
|
||||||
|
ExprHash: billingexpr.ExprHashString(`param("stream") == true ? tier("stream", p * 3) : tier("base", p * 2)`),
|
||||||
|
GroupRatio: 1,
|
||||||
|
EstimatedTier: "stream",
|
||||||
|
QuotaPerUnit: common.QuotaPerUnit,
|
||||||
|
ExprVersion: 1,
|
||||||
|
},
|
||||||
|
BillingRequestInput: &billingexpr.RequestInput{
|
||||||
|
Body: []byte(`{"stream":true}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
quota, result := settleTestQuota(info, types.PriceData{
|
||||||
|
ModelRatio: 1,
|
||||||
|
CompletionRatio: 2,
|
||||||
|
}, &dto.Usage{
|
||||||
|
PromptTokens: 1000,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Equal(t, 1500, quota)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, "stream", result.MatchedTier)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildTestLogOtherInjectsTieredInfo(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||||
|
|
||||||
|
info := &relaycommon.RelayInfo{
|
||||||
|
TieredBillingSnapshot: &billingexpr.BillingSnapshot{
|
||||||
|
BillingMode: "tiered_expr",
|
||||||
|
ExprString: `tier("base", p * 2)`,
|
||||||
|
},
|
||||||
|
ChannelMeta: &relaycommon.ChannelMeta{},
|
||||||
|
}
|
||||||
|
priceData := types.PriceData{
|
||||||
|
GroupRatioInfo: types.GroupRatioInfo{GroupRatio: 1},
|
||||||
|
}
|
||||||
|
usage := &dto.Usage{
|
||||||
|
PromptTokensDetails: dto.InputTokenDetails{
|
||||||
|
CachedTokens: 12,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
other := buildTestLogOther(ctx, info, priceData, usage, &billingexpr.TieredResult{
|
||||||
|
MatchedTier: "base",
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Equal(t, "tiered_expr", other["billing_mode"])
|
||||||
|
require.Equal(t, "base", other["matched_tier"])
|
||||||
|
require.NotEmpty(t, other["expr_b64"])
|
||||||
|
}
|
||||||
@@ -4,12 +4,21 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/QuantumNous/new-api/common"
|
"github.com/QuantumNous/new-api/common"
|
||||||
|
"github.com/QuantumNous/new-api/dto"
|
||||||
"github.com/QuantumNous/new-api/pkg/billingexpr"
|
"github.com/QuantumNous/new-api/pkg/billingexpr"
|
||||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ResolveIncomingBillingExprRequestInput(c *gin.Context, info *relaycommon.RelayInfo) (billingexpr.RequestInput, error) {
|
func ResolveIncomingBillingExprRequestInput(c *gin.Context, info *relaycommon.RelayInfo) (billingexpr.RequestInput, error) {
|
||||||
|
if info != nil && info.BillingRequestInput != nil {
|
||||||
|
input := cloneRequestInput(*info.BillingRequestInput)
|
||||||
|
if len(input.Headers) == 0 {
|
||||||
|
input.Headers = cloneStringMap(info.RequestHeaders)
|
||||||
|
}
|
||||||
|
return input, nil
|
||||||
|
}
|
||||||
|
|
||||||
input := billingexpr.RequestInput{}
|
input := billingexpr.RequestInput{}
|
||||||
if info != nil {
|
if info != nil {
|
||||||
input.Headers = cloneStringMap(info.RequestHeaders)
|
input.Headers = cloneStringMap(info.RequestHeaders)
|
||||||
@@ -23,6 +32,22 @@ func ResolveIncomingBillingExprRequestInput(c *gin.Context, info *relaycommon.Re
|
|||||||
return input, nil
|
return input, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func BuildBillingExprRequestInputFromRequest(request dto.Request, headers map[string]string) (billingexpr.RequestInput, error) {
|
||||||
|
input := billingexpr.RequestInput{
|
||||||
|
Headers: cloneStringMap(headers),
|
||||||
|
}
|
||||||
|
if request == nil {
|
||||||
|
return input, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, err := common.Marshal(request)
|
||||||
|
if err != nil {
|
||||||
|
return billingexpr.RequestInput{}, err
|
||||||
|
}
|
||||||
|
input.Body = bodyBytes
|
||||||
|
return input, nil
|
||||||
|
}
|
||||||
|
|
||||||
func readIncomingBillingExprBody(c *gin.Context) ([]byte, error) {
|
func readIncomingBillingExprBody(c *gin.Context) ([]byte, error) {
|
||||||
if c == nil || c.Request == nil || !isJSONContentType(c.Request.Header.Get("Content-Type")) {
|
if c == nil || c.Request == nil || !isJSONContentType(c.Request.Header.Get("Content-Type")) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@@ -34,6 +59,16 @@ func readIncomingBillingExprBody(c *gin.Context) ([]byte, error) {
|
|||||||
return storage.Bytes()
|
return storage.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func cloneRequestInput(src billingexpr.RequestInput) billingexpr.RequestInput {
|
||||||
|
input := billingexpr.RequestInput{
|
||||||
|
Headers: cloneStringMap(src.Headers),
|
||||||
|
}
|
||||||
|
if len(src.Body) > 0 {
|
||||||
|
input.Body = append([]byte(nil), src.Body...)
|
||||||
|
}
|
||||||
|
return input
|
||||||
|
}
|
||||||
|
|
||||||
func isJSONContentType(contentType string) bool {
|
func isJSONContentType(contentType string) bool {
|
||||||
contentType = strings.ToLower(strings.TrimSpace(contentType))
|
contentType = strings.ToLower(strings.TrimSpace(contentType))
|
||||||
return strings.HasPrefix(contentType, "application/json")
|
return strings.HasPrefix(contentType, "application/json")
|
||||||
|
|||||||
@@ -8,9 +8,12 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/QuantumNous/new-api/common"
|
"github.com/QuantumNous/new-api/common"
|
||||||
|
"github.com/QuantumNous/new-api/dto"
|
||||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/samber/lo"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestResolveIncomingBillingExprRequestInput(t *testing.T) {
|
func TestResolveIncomingBillingExprRequestInput(t *testing.T) {
|
||||||
@@ -33,3 +36,28 @@ func TestResolveIncomingBillingExprRequestInput(t *testing.T) {
|
|||||||
require.Equal(t, body, input.Body)
|
require.Equal(t, body, input.Body)
|
||||||
require.Equal(t, "application/json", input.Headers["Content-Type"])
|
require.Equal(t, "application/json", input.Headers["Content-Type"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBuildBillingExprRequestInputFromRequest(t *testing.T) {
|
||||||
|
request := &dto.GeneralOpenAIRequest{
|
||||||
|
Model: "gemini-3.1-pro-preview",
|
||||||
|
Stream: lo.ToPtr(true),
|
||||||
|
Messages: []dto.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "hi",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
MaxTokens: lo.ToPtr(uint(3000)),
|
||||||
|
}
|
||||||
|
|
||||||
|
input, err := BuildBillingExprRequestInputFromRequest(request, map[string]string{
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"X-Test": "1",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "application/json", input.Headers["Content-Type"])
|
||||||
|
require.Equal(t, "1", input.Headers["X-Test"])
|
||||||
|
require.True(t, gjson.GetBytes(input.Body, "stream").Bool())
|
||||||
|
require.Equal(t, "user", gjson.GetBytes(input.Body, "messages.0.role").String())
|
||||||
|
require.Equal(t, float64(3000), gjson.GetBytes(input.Body, "max_tokens").Float())
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,62 @@
|
|||||||
|
package helper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/QuantumNous/new-api/common"
|
||||||
|
"github.com/QuantumNous/new-api/pkg/billingexpr"
|
||||||
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
|
"github.com/QuantumNous/new-api/setting/billing_setting"
|
||||||
|
"github.com/QuantumNous/new-api/setting/config"
|
||||||
|
"github.com/QuantumNous/new-api/types"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestModelPriceHelperTieredUsesPreloadedRequestInput(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
saved := map[string]string{}
|
||||||
|
require.NoError(t, config.GlobalConfig.SaveToDB(func(key, value string) error {
|
||||||
|
saved[key] = value
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, config.GlobalConfig.LoadFromDB(saved))
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, config.GlobalConfig.LoadFromDB(map[string]string{
|
||||||
|
"billing_setting.billing_mode": `{"tiered-test-model":"tiered_expr"}`,
|
||||||
|
"billing_setting.billing_expr": `{"tiered-test-model":"param(\"stream\") == true ? tier(\"stream\", p * 3) : tier(\"base\", p * 2)"}`,
|
||||||
|
}))
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(recorder)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/channel/test/1", nil)
|
||||||
|
req.Body = nil
|
||||||
|
req.ContentLength = 0
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
ctx.Request = req
|
||||||
|
ctx.Set("group", "default")
|
||||||
|
|
||||||
|
info := &relaycommon.RelayInfo{
|
||||||
|
OriginModelName: "tiered-test-model",
|
||||||
|
UserGroup: "default",
|
||||||
|
UsingGroup: "default",
|
||||||
|
RequestHeaders: map[string]string{"Content-Type": "application/json"},
|
||||||
|
BillingRequestInput: &billingexpr.RequestInput{
|
||||||
|
Headers: map[string]string{"Content-Type": "application/json"},
|
||||||
|
Body: []byte(`{"stream":true}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
priceData, err := ModelPriceHelper(ctx, info, 1000, &types.TokenCountMeta{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1500, priceData.QuotaToPreConsume)
|
||||||
|
require.NotNil(t, info.TieredBillingSnapshot)
|
||||||
|
require.Equal(t, "stream", info.TieredBillingSnapshot.EstimatedTier)
|
||||||
|
require.Equal(t, billing_setting.BillingModeTieredExpr, info.TieredBillingSnapshot.BillingMode)
|
||||||
|
require.Equal(t, common.QuotaPerUnit, info.TieredBillingSnapshot.QuotaPerUnit)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user