diff --git a/service/text_quota.go b/service/text_quota.go index 653166fe..e2ec87b2 100644 --- a/service/text_quota.go +++ b/service/text_quota.go @@ -52,6 +52,7 @@ type textQuotaSummary struct { FileSearchCallCount int AudioInputPrice float64 ImageGenerationCallPrice float64 + ToolCallSurchargeQuota decimal.Decimal } func cacheWriteTokensTotal(summary textQuotaSummary) int { @@ -78,6 +79,81 @@ func isLegacyClaudeDerivedOpenAIUsage(relayInfo *relaycommon.RelayInfo, usage *d return usage.ClaudeCacheCreation5mTokens > 0 || usage.ClaudeCacheCreation1hTokens > 0 } +func calculateTextToolCallSurcharge(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, summary *textQuotaSummary) decimal.Decimal { + dGroupRatio := decimal.NewFromFloat(summary.GroupRatio) + dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit) + + var surcharge decimal.Decimal + + if relayInfo.ResponsesUsageInfo != nil { + if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 { + summary.WebSearchCallCount = webSearchTool.CallCount + summary.WebSearchPrice = operation_setting.GetToolPriceForModel("web_search_preview", summary.ModelName) + surcharge = surcharge.Add(decimal.NewFromFloat(summary.WebSearchPrice). + Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))). + Div(decimal.NewFromInt(1000)). + Mul(dGroupRatio). + Mul(dQuotaPerUnit)) + } + } else if strings.HasSuffix(summary.ModelName, "search-preview") { + summary.WebSearchCallCount = 1 + summary.WebSearchPrice = operation_setting.GetToolPriceForModel("web_search_preview", summary.ModelName) + surcharge = surcharge.Add(decimal.NewFromFloat(summary.WebSearchPrice). + Div(decimal.NewFromInt(1000)). + Mul(dGroupRatio). + Mul(dQuotaPerUnit)) + } + + summary.ClaudeWebSearchCallCount = ctx.GetInt("claude_web_search_requests") + if summary.ClaudeWebSearchCallCount > 0 { + summary.ClaudeWebSearchPrice = operation_setting.GetToolPrice("web_search") + surcharge = surcharge.Add(decimal.NewFromFloat(summary.ClaudeWebSearchPrice). + Div(decimal.NewFromInt(1000)). + Mul(dGroupRatio). + Mul(dQuotaPerUnit). + Mul(decimal.NewFromInt(int64(summary.ClaudeWebSearchCallCount)))) + } + + if relayInfo.ResponsesUsageInfo != nil { + if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists && fileSearchTool.CallCount > 0 { + summary.FileSearchCallCount = fileSearchTool.CallCount + summary.FileSearchPrice = operation_setting.GetToolPrice("file_search") + surcharge = surcharge.Add(decimal.NewFromFloat(summary.FileSearchPrice). + Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))). + Div(decimal.NewFromInt(1000)). + Mul(dGroupRatio). + Mul(dQuotaPerUnit)) + } + } + + if ctx.GetBool("image_generation_call") { + summary.ImageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size")) + surcharge = surcharge.Add(decimal.NewFromFloat(summary.ImageGenerationCallPrice). + Mul(dGroupRatio). + Mul(dQuotaPerUnit)) + } + + return surcharge +} + +func composeTieredTextQuota(relayInfo *relaycommon.RelayInfo, summary textQuotaSummary, tieredQuota int, tieredResult *billingexpr.TieredResult) int { + if summary.ToolCallSurchargeQuota.IsZero() { + return tieredQuota + } + + if tieredResult != nil { + if snap := relayInfo.TieredBillingSnapshot; snap != nil { + return int(decimal.NewFromFloat(tieredResult.ActualQuotaBeforeGroup). + Mul(decimal.NewFromFloat(snap.GroupRatio)). + Add(summary.ToolCallSurchargeQuota). + Round(0). + IntPart()) + } + } + + return tieredQuota + int(summary.ToolCallSurchargeQuota.Round(0).IntPart()) +} + func calculateTextQuotaSummary(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage) textQuotaSummary { summary := textQuotaSummary{ ModelName: relayInfo.OriginModelName, @@ -148,52 +224,7 @@ func calculateTextQuotaSummary(ctx *gin.Context, relayInfo *relaycommon.RelayInf dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit) ratio := dModelRatio.Mul(dGroupRatio) - - var dWebSearchQuota decimal.Decimal - if relayInfo.ResponsesUsageInfo != nil { - if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 { - summary.WebSearchCallCount = webSearchTool.CallCount - summary.WebSearchPrice = operation_setting.GetToolPriceForModel("web_search_preview", summary.ModelName) - dWebSearchQuota = decimal.NewFromFloat(summary.WebSearchPrice). - Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))). - Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit) - } - } else if strings.HasSuffix(summary.ModelName, "search-preview") { - searchContextSize := ctx.GetString("chat_completion_web_search_context_size") - if searchContextSize == "" { - searchContextSize = "medium" - } - summary.WebSearchCallCount = 1 - summary.WebSearchPrice = operation_setting.GetToolPriceForModel("web_search_preview", summary.ModelName) - dWebSearchQuota = decimal.NewFromFloat(summary.WebSearchPrice). - Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit) - } - - var dClaudeWebSearchQuota decimal.Decimal - summary.ClaudeWebSearchCallCount = ctx.GetInt("claude_web_search_requests") - if summary.ClaudeWebSearchCallCount > 0 { - summary.ClaudeWebSearchPrice = operation_setting.GetToolPrice("web_search") - dClaudeWebSearchQuota = decimal.NewFromFloat(summary.ClaudeWebSearchPrice). - Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit). - Mul(decimal.NewFromInt(int64(summary.ClaudeWebSearchCallCount))) - } - - var dFileSearchQuota decimal.Decimal - if relayInfo.ResponsesUsageInfo != nil { - if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists && fileSearchTool.CallCount > 0 { - summary.FileSearchCallCount = fileSearchTool.CallCount - summary.FileSearchPrice = operation_setting.GetToolPrice("file_search") - dFileSearchQuota = decimal.NewFromFloat(summary.FileSearchPrice). - Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))). - Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit) - } - } - - var dImageGenerationCallQuota decimal.Decimal - if ctx.GetBool("image_generation_call") { - summary.ImageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size")) - dImageGenerationCallQuota = decimal.NewFromFloat(summary.ImageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit) - } + summary.ToolCallSurchargeQuota = calculateTextToolCallSurcharge(ctx, relayInfo, &summary) var audioInputQuota decimal.Decimal if !relayInfo.PriceData.UsePrice { @@ -242,11 +273,8 @@ func calculateTextQuotaSummary(ctx *gin.Context, relayInfo *relaycommon.RelayInf promptQuota := baseTokens.Add(cachedTokensWithRatio).Add(imageTokensWithRatio).Add(cachedCreationTokensWithRatio) completionQuota := dCompletionTokens.Mul(dCompletionRatio) quotaCalculateDecimal := promptQuota.Add(completionQuota).Mul(ratio) - quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota) - quotaCalculateDecimal = quotaCalculateDecimal.Add(dClaudeWebSearchQuota) - quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota) + quotaCalculateDecimal = quotaCalculateDecimal.Add(summary.ToolCallSurchargeQuota) quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota) - quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota) if len(relayInfo.PriceData.OtherRatios) > 0 { for _, otherRatio := range relayInfo.PriceData.OtherRatios { @@ -260,11 +288,8 @@ func calculateTextQuotaSummary(ctx *gin.Context, relayInfo *relaycommon.RelayInf summary.Quota = int(quotaCalculateDecimal.Round(0).IntPart()) } else { quotaCalculateDecimal := dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio) - quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota) - quotaCalculateDecimal = quotaCalculateDecimal.Add(dClaudeWebSearchQuota) - quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota) + quotaCalculateDecimal = quotaCalculateDecimal.Add(summary.ToolCallSurchargeQuota) quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota) - quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota) if len(relayInfo.PriceData.OtherRatios) > 0 { for _, otherRatio := range relayInfo.PriceData.OtherRatios { quotaCalculateDecimal = quotaCalculateDecimal.Mul(decimal.NewFromFloat(otherRatio)) @@ -313,7 +338,7 @@ func PostTextConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, us tieredOk, tieredQuota, tieredRes := TryTieredSettle(relayInfo, BuildTieredTokenParams(usage, summary.IsClaudeUsageSemantic, tieredUsedVars)) if tieredOk { tieredResult = tieredRes - summary.Quota = tieredQuota + summary.Quota = composeTieredTextQuota(relayInfo, summary, tieredQuota, tieredRes) } } diff --git a/service/text_quota_test.go b/service/text_quota_test.go index e995de17..37ce1877 100644 --- a/service/text_quota_test.go +++ b/service/text_quota_test.go @@ -7,6 +7,7 @@ import ( "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/pkg/billingexpr" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/types" @@ -316,3 +317,125 @@ func TestCalculateTextQuotaSummaryKeepsPrePRClaudeOpenRouterBilling(t *testing.T require.Equal(t, 172, summary.PromptTokens) require.Equal(t, 798, summary.Quota) } + +func TestComposeTieredTextQuotaKeepsToolCallSurcharges(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(w) + ctx.Set("image_generation_call", true) + ctx.Set("image_generation_call_quality", "low") + ctx.Set("image_generation_call_size", "1024x1024") + + relayInfo := &relaycommon.RelayInfo{ + OriginModelName: "o1", + PriceData: types.PriceData{ + ModelRatio: 1, + CompletionRatio: 1, + GroupRatioInfo: types.GroupRatioInfo{GroupRatio: 1}, + }, + ResponsesUsageInfo: &relaycommon.ResponsesUsageInfo{ + BuiltInTools: map[string]*relaycommon.BuildInToolInfo{ + dto.BuildInToolWebSearchPreview: &relaycommon.BuildInToolInfo{ + CallCount: 1, + }, + dto.BuildInToolFileSearch: &relaycommon.BuildInToolInfo{ + CallCount: 2, + }, + }, + }, + TieredBillingSnapshot: &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + GroupRatio: 1, + EstimatedQuotaBeforeGroup: 1000, + }, + StartTime: time.Now(), + } + + usage := &dto.Usage{ + PromptTokens: 100, + CompletionTokens: 50, + TotalTokens: 150, + } + + summary := calculateTextQuotaSummary(ctx, relayInfo, usage) + quota := composeTieredTextQuota(relayInfo, summary, 1000, &billingexpr.TieredResult{ + ActualQuotaBeforeGroup: 1000, + ActualQuotaAfterGroup: 1000, + }) + + require.Equal(t, int64(13000), summary.ToolCallSurchargeQuota.Round(0).IntPart()) + require.Equal(t, 14000, quota) +} + +func TestComposeTieredTextQuotaFallbackKeepsToolCallSurcharges(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(w) + ctx.Set("claude_web_search_requests", 2) + + relayInfo := &relaycommon.RelayInfo{ + OriginModelName: "claude-3-7-sonnet", + PriceData: types.PriceData{ + ModelRatio: 1, + CompletionRatio: 1, + GroupRatioInfo: types.GroupRatioInfo{GroupRatio: 1.25}, + }, + TieredBillingSnapshot: &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + GroupRatio: 1.25, + EstimatedQuotaBeforeGroup: 1000, + }, + StartTime: time.Now(), + } + + usage := &dto.Usage{ + PromptTokens: 100, + CompletionTokens: 50, + TotalTokens: 150, + } + + summary := calculateTextQuotaSummary(ctx, relayInfo, usage) + quota := composeTieredTextQuota(relayInfo, summary, 1250, nil) + + require.Equal(t, int64(12500), summary.ToolCallSurchargeQuota.Round(0).IntPart()) + require.Equal(t, 13750, quota) +} + +func TestComposeTieredTextQuotaErrorFallbackUsesPreConsumedQuota(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(w) + ctx.Set("claude_web_search_requests", 2) + + relayInfo := &relaycommon.RelayInfo{ + OriginModelName: "claude-3-7-sonnet", + PriceData: types.PriceData{ + ModelRatio: 1, + CompletionRatio: 1, + GroupRatioInfo: types.GroupRatioInfo{GroupRatio: 1.25}, + }, + TieredBillingSnapshot: &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + GroupRatio: 1.25, + EstimatedQuotaBeforeGroup: 1000, + }, + StartTime: time.Now(), + } + + usage := &dto.Usage{ + PromptTokens: 100, + CompletionTokens: 50, + TotalTokens: 150, + } + + summary := calculateTextQuotaSummary(ctx, relayInfo, usage) + + // tieredResult=nil simulates a settlement error where TryTieredSettle + // falls back to FinalPreConsumedQuota (2000), which differs from + // EstimatedQuotaBeforeGroup * GroupRatio (1250). + preConsumedFallback := 2000 + quota := composeTieredTextQuota(relayInfo, summary, preConsumedFallback, nil) + + require.Equal(t, int64(12500), summary.ToolCallSurchargeQuota.Round(0).IntPart()) + require.Equal(t, 14500, quota) +}