From 1fe9f6f9892fdc3ef8bdc6fa2436c63cea66622c Mon Sep 17 00:00:00 2001 From: yyhhyyyyyy Date: Thu, 9 Apr 2026 18:18:01 +0800 Subject: [PATCH 1/2] fix(billing): preserve text tool surcharges in tiered settlement --- mise.toml | 2 + service/text_quota.go | 143 +++++++++++++++++++++++-------------- service/text_quota_test.go | 84 ++++++++++++++++++++++ 3 files changed, 174 insertions(+), 55 deletions(-) create mode 100644 mise.toml diff --git a/mise.toml b/mise.toml new file mode 100644 index 00000000..a94d1edc --- /dev/null +++ b/mise.toml @@ -0,0 +1,2 @@ +[tools] +bun = "latest" diff --git a/service/text_quota.go b/service/text_quota.go index 653166fe..07017ad1 100644 --- a/service/text_quota.go +++ b/service/text_quota.go @@ -52,6 +52,7 @@ type textQuotaSummary struct { FileSearchCallCount int AudioInputPrice float64 ImageGenerationCallPrice float64 + ToolCallSurchargeQuota decimal.Decimal } func cacheWriteTokensTotal(summary textQuotaSummary) int { @@ -78,6 +79,89 @@ func isLegacyClaudeDerivedOpenAIUsage(relayInfo *relaycommon.RelayInfo, usage *d return usage.ClaudeCacheCreation5mTokens > 0 || usage.ClaudeCacheCreation1hTokens > 0 } +func calculateTextToolCallSurcharge(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, summary *textQuotaSummary) decimal.Decimal { + dGroupRatio := decimal.NewFromFloat(summary.GroupRatio) + dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit) + + var surcharge decimal.Decimal + + if relayInfo.ResponsesUsageInfo != nil { + if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 { + summary.WebSearchCallCount = webSearchTool.CallCount + summary.WebSearchPrice = operation_setting.GetToolPriceForModel("web_search_preview", summary.ModelName) + surcharge = surcharge.Add(decimal.NewFromFloat(summary.WebSearchPrice). + Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))). + Div(decimal.NewFromInt(1000)). + Mul(dGroupRatio). + Mul(dQuotaPerUnit)) + } + } else if strings.HasSuffix(summary.ModelName, "search-preview") { + summary.WebSearchCallCount = 1 + summary.WebSearchPrice = operation_setting.GetToolPriceForModel("web_search_preview", summary.ModelName) + surcharge = surcharge.Add(decimal.NewFromFloat(summary.WebSearchPrice). + Div(decimal.NewFromInt(1000)). + Mul(dGroupRatio). + Mul(dQuotaPerUnit)) + } + + summary.ClaudeWebSearchCallCount = ctx.GetInt("claude_web_search_requests") + if summary.ClaudeWebSearchCallCount > 0 { + summary.ClaudeWebSearchPrice = operation_setting.GetToolPrice("web_search") + surcharge = surcharge.Add(decimal.NewFromFloat(summary.ClaudeWebSearchPrice). + Div(decimal.NewFromInt(1000)). + Mul(dGroupRatio). + Mul(dQuotaPerUnit). + Mul(decimal.NewFromInt(int64(summary.ClaudeWebSearchCallCount)))) + } + + if relayInfo.ResponsesUsageInfo != nil { + if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists && fileSearchTool.CallCount > 0 { + summary.FileSearchCallCount = fileSearchTool.CallCount + summary.FileSearchPrice = operation_setting.GetToolPrice("file_search") + surcharge = surcharge.Add(decimal.NewFromFloat(summary.FileSearchPrice). + Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))). + Div(decimal.NewFromInt(1000)). + Mul(dGroupRatio). + Mul(dQuotaPerUnit)) + } + } + + if ctx.GetBool("image_generation_call") { + summary.ImageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size")) + surcharge = surcharge.Add(decimal.NewFromFloat(summary.ImageGenerationCallPrice). + Mul(dGroupRatio). + Mul(dQuotaPerUnit)) + } + + return surcharge +} + +func composeTieredTextQuota(relayInfo *relaycommon.RelayInfo, summary textQuotaSummary, tieredQuota int, tieredResult *billingexpr.TieredResult) int { + if summary.ToolCallSurchargeQuota.IsZero() { + return tieredQuota + } + + if tieredResult != nil { + if snap := relayInfo.TieredBillingSnapshot; snap != nil { + return int(decimal.NewFromFloat(tieredResult.ActualQuotaBeforeGroup). + Mul(decimal.NewFromFloat(snap.GroupRatio)). + Add(summary.ToolCallSurchargeQuota). + Round(0). + IntPart()) + } + } + + if snap := relayInfo.TieredBillingSnapshot; snap != nil { + return int(decimal.NewFromFloat(snap.EstimatedQuotaBeforeGroup). + Mul(decimal.NewFromFloat(snap.GroupRatio)). + Add(summary.ToolCallSurchargeQuota). + Round(0). + IntPart()) + } + + return tieredQuota + int(summary.ToolCallSurchargeQuota.Round(0).IntPart()) +} + func calculateTextQuotaSummary(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage) textQuotaSummary { summary := textQuotaSummary{ ModelName: relayInfo.OriginModelName, @@ -148,52 +232,7 @@ func calculateTextQuotaSummary(ctx *gin.Context, relayInfo *relaycommon.RelayInf dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit) ratio := dModelRatio.Mul(dGroupRatio) - - var dWebSearchQuota decimal.Decimal - if relayInfo.ResponsesUsageInfo != nil { - if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 { - summary.WebSearchCallCount = webSearchTool.CallCount - summary.WebSearchPrice = operation_setting.GetToolPriceForModel("web_search_preview", summary.ModelName) - dWebSearchQuota = decimal.NewFromFloat(summary.WebSearchPrice). - Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))). - Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit) - } - } else if strings.HasSuffix(summary.ModelName, "search-preview") { - searchContextSize := ctx.GetString("chat_completion_web_search_context_size") - if searchContextSize == "" { - searchContextSize = "medium" - } - summary.WebSearchCallCount = 1 - summary.WebSearchPrice = operation_setting.GetToolPriceForModel("web_search_preview", summary.ModelName) - dWebSearchQuota = decimal.NewFromFloat(summary.WebSearchPrice). - Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit) - } - - var dClaudeWebSearchQuota decimal.Decimal - summary.ClaudeWebSearchCallCount = ctx.GetInt("claude_web_search_requests") - if summary.ClaudeWebSearchCallCount > 0 { - summary.ClaudeWebSearchPrice = operation_setting.GetToolPrice("web_search") - dClaudeWebSearchQuota = decimal.NewFromFloat(summary.ClaudeWebSearchPrice). - Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit). - Mul(decimal.NewFromInt(int64(summary.ClaudeWebSearchCallCount))) - } - - var dFileSearchQuota decimal.Decimal - if relayInfo.ResponsesUsageInfo != nil { - if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists && fileSearchTool.CallCount > 0 { - summary.FileSearchCallCount = fileSearchTool.CallCount - summary.FileSearchPrice = operation_setting.GetToolPrice("file_search") - dFileSearchQuota = decimal.NewFromFloat(summary.FileSearchPrice). - Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))). - Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit) - } - } - - var dImageGenerationCallQuota decimal.Decimal - if ctx.GetBool("image_generation_call") { - summary.ImageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size")) - dImageGenerationCallQuota = decimal.NewFromFloat(summary.ImageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit) - } + summary.ToolCallSurchargeQuota = calculateTextToolCallSurcharge(ctx, relayInfo, &summary) var audioInputQuota decimal.Decimal if !relayInfo.PriceData.UsePrice { @@ -242,11 +281,8 @@ func calculateTextQuotaSummary(ctx *gin.Context, relayInfo *relaycommon.RelayInf promptQuota := baseTokens.Add(cachedTokensWithRatio).Add(imageTokensWithRatio).Add(cachedCreationTokensWithRatio) completionQuota := dCompletionTokens.Mul(dCompletionRatio) quotaCalculateDecimal := promptQuota.Add(completionQuota).Mul(ratio) - quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota) - quotaCalculateDecimal = quotaCalculateDecimal.Add(dClaudeWebSearchQuota) - quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota) + quotaCalculateDecimal = quotaCalculateDecimal.Add(summary.ToolCallSurchargeQuota) quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota) - quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota) if len(relayInfo.PriceData.OtherRatios) > 0 { for _, otherRatio := range relayInfo.PriceData.OtherRatios { @@ -260,11 +296,8 @@ func calculateTextQuotaSummary(ctx *gin.Context, relayInfo *relaycommon.RelayInf summary.Quota = int(quotaCalculateDecimal.Round(0).IntPart()) } else { quotaCalculateDecimal := dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio) - quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota) - quotaCalculateDecimal = quotaCalculateDecimal.Add(dClaudeWebSearchQuota) - quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota) + quotaCalculateDecimal = quotaCalculateDecimal.Add(summary.ToolCallSurchargeQuota) quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota) - quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota) if len(relayInfo.PriceData.OtherRatios) > 0 { for _, otherRatio := range relayInfo.PriceData.OtherRatios { quotaCalculateDecimal = quotaCalculateDecimal.Mul(decimal.NewFromFloat(otherRatio)) @@ -313,7 +346,7 @@ func PostTextConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, us tieredOk, tieredQuota, tieredRes := TryTieredSettle(relayInfo, BuildTieredTokenParams(usage, summary.IsClaudeUsageSemantic, tieredUsedVars)) if tieredOk { tieredResult = tieredRes - summary.Quota = tieredQuota + summary.Quota = composeTieredTextQuota(relayInfo, summary, tieredQuota, tieredRes) } } diff --git a/service/text_quota_test.go b/service/text_quota_test.go index e995de17..deeb1189 100644 --- a/service/text_quota_test.go +++ b/service/text_quota_test.go @@ -7,6 +7,7 @@ import ( "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/pkg/billingexpr" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/types" @@ -316,3 +317,86 @@ func TestCalculateTextQuotaSummaryKeepsPrePRClaudeOpenRouterBilling(t *testing.T require.Equal(t, 172, summary.PromptTokens) require.Equal(t, 798, summary.Quota) } + +func TestComposeTieredTextQuotaKeepsToolCallSurcharges(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(w) + ctx.Set("image_generation_call", true) + ctx.Set("image_generation_call_quality", "low") + ctx.Set("image_generation_call_size", "1024x1024") + + relayInfo := &relaycommon.RelayInfo{ + OriginModelName: "o1", + PriceData: types.PriceData{ + ModelRatio: 1, + CompletionRatio: 1, + GroupRatioInfo: types.GroupRatioInfo{GroupRatio: 1}, + }, + ResponsesUsageInfo: &relaycommon.ResponsesUsageInfo{ + BuiltInTools: map[string]*relaycommon.BuildInToolInfo{ + dto.BuildInToolWebSearchPreview: &relaycommon.BuildInToolInfo{ + CallCount: 1, + }, + dto.BuildInToolFileSearch: &relaycommon.BuildInToolInfo{ + CallCount: 2, + }, + }, + }, + TieredBillingSnapshot: &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + GroupRatio: 1, + EstimatedQuotaBeforeGroup: 1000, + }, + StartTime: time.Now(), + } + + usage := &dto.Usage{ + PromptTokens: 100, + CompletionTokens: 50, + TotalTokens: 150, + } + + summary := calculateTextQuotaSummary(ctx, relayInfo, usage) + quota := composeTieredTextQuota(relayInfo, summary, 1000, &billingexpr.TieredResult{ + ActualQuotaBeforeGroup: 1000, + ActualQuotaAfterGroup: 1000, + }) + + require.Equal(t, int64(13000), summary.ToolCallSurchargeQuota.Round(0).IntPart()) + require.Equal(t, 14000, quota) +} + +func TestComposeTieredTextQuotaFallbackKeepsToolCallSurcharges(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(w) + ctx.Set("claude_web_search_requests", 2) + + relayInfo := &relaycommon.RelayInfo{ + OriginModelName: "claude-3-7-sonnet", + PriceData: types.PriceData{ + ModelRatio: 1, + CompletionRatio: 1, + GroupRatioInfo: types.GroupRatioInfo{GroupRatio: 1.25}, + }, + TieredBillingSnapshot: &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + GroupRatio: 1.25, + EstimatedQuotaBeforeGroup: 1000, + }, + StartTime: time.Now(), + } + + usage := &dto.Usage{ + PromptTokens: 100, + CompletionTokens: 50, + TotalTokens: 150, + } + + summary := calculateTextQuotaSummary(ctx, relayInfo, usage) + quota := composeTieredTextQuota(relayInfo, summary, 1250, nil) + + require.Equal(t, int64(12500), summary.ToolCallSurchargeQuota.Round(0).IntPart()) + require.Equal(t, 13750, quota) +} From 5c4ed5be99d63410bbad9a528588b87e32a23721 Mon Sep 17 00:00:00 2001 From: CaIon Date: Thu, 23 Apr 2026 18:59:48 +0800 Subject: [PATCH 2/2] fix(billing): use tieredQuota fallback in composeTieredTextQuota error path Remove the intermediate branch that recomputed quota from EstimatedQuotaBeforeGroup when tieredResult is nil. This discarded the FinalPreConsumedQuota fallback that TryTieredSettle already selected. Now the error path simply adds tool surcharges to the passed-in tieredQuota, preserving the existing fallback semantics. Also removes unrelated mise.toml and adds a test covering the error fallback with a pre-consumed quota that differs from the estimate. --- mise.toml | 2 -- service/text_quota.go | 8 -------- service/text_quota_test.go | 39 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 10 deletions(-) delete mode 100644 mise.toml diff --git a/mise.toml b/mise.toml deleted file mode 100644 index a94d1edc..00000000 --- a/mise.toml +++ /dev/null @@ -1,2 +0,0 @@ -[tools] -bun = "latest" diff --git a/service/text_quota.go b/service/text_quota.go index 07017ad1..e2ec87b2 100644 --- a/service/text_quota.go +++ b/service/text_quota.go @@ -151,14 +151,6 @@ func composeTieredTextQuota(relayInfo *relaycommon.RelayInfo, summary textQuotaS } } - if snap := relayInfo.TieredBillingSnapshot; snap != nil { - return int(decimal.NewFromFloat(snap.EstimatedQuotaBeforeGroup). - Mul(decimal.NewFromFloat(snap.GroupRatio)). - Add(summary.ToolCallSurchargeQuota). - Round(0). - IntPart()) - } - return tieredQuota + int(summary.ToolCallSurchargeQuota.Round(0).IntPart()) } diff --git a/service/text_quota_test.go b/service/text_quota_test.go index deeb1189..37ce1877 100644 --- a/service/text_quota_test.go +++ b/service/text_quota_test.go @@ -400,3 +400,42 @@ func TestComposeTieredTextQuotaFallbackKeepsToolCallSurcharges(t *testing.T) { require.Equal(t, int64(12500), summary.ToolCallSurchargeQuota.Round(0).IntPart()) require.Equal(t, 13750, quota) } + +func TestComposeTieredTextQuotaErrorFallbackUsesPreConsumedQuota(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(w) + ctx.Set("claude_web_search_requests", 2) + + relayInfo := &relaycommon.RelayInfo{ + OriginModelName: "claude-3-7-sonnet", + PriceData: types.PriceData{ + ModelRatio: 1, + CompletionRatio: 1, + GroupRatioInfo: types.GroupRatioInfo{GroupRatio: 1.25}, + }, + TieredBillingSnapshot: &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + GroupRatio: 1.25, + EstimatedQuotaBeforeGroup: 1000, + }, + StartTime: time.Now(), + } + + usage := &dto.Usage{ + PromptTokens: 100, + CompletionTokens: 50, + TotalTokens: 150, + } + + summary := calculateTextQuotaSummary(ctx, relayInfo, usage) + + // tieredResult=nil simulates a settlement error where TryTieredSettle + // falls back to FinalPreConsumedQuota (2000), which differs from + // EstimatedQuotaBeforeGroup * GroupRatio (1250). + preConsumedFallback := 2000 + quota := composeTieredTextQuota(relayInfo, summary, preConsumedFallback, nil) + + require.Equal(t, int64(12500), summary.ToolCallSurchargeQuota.Round(0).IntPart()) + require.Equal(t, 14500, quota) +}