Merge pull request #4162 from yyhhyyyyyy/fix/tiered-text-tool-surcharge
fix(billing): preserve text tool surcharges in tiered settlement
This commit is contained in:
+80
-55
@@ -52,6 +52,7 @@ type textQuotaSummary struct {
|
|||||||
FileSearchCallCount int
|
FileSearchCallCount int
|
||||||
AudioInputPrice float64
|
AudioInputPrice float64
|
||||||
ImageGenerationCallPrice float64
|
ImageGenerationCallPrice float64
|
||||||
|
ToolCallSurchargeQuota decimal.Decimal
|
||||||
}
|
}
|
||||||
|
|
||||||
func cacheWriteTokensTotal(summary textQuotaSummary) int {
|
func cacheWriteTokensTotal(summary textQuotaSummary) int {
|
||||||
@@ -78,6 +79,81 @@ func isLegacyClaudeDerivedOpenAIUsage(relayInfo *relaycommon.RelayInfo, usage *d
|
|||||||
return usage.ClaudeCacheCreation5mTokens > 0 || usage.ClaudeCacheCreation1hTokens > 0
|
return usage.ClaudeCacheCreation5mTokens > 0 || usage.ClaudeCacheCreation1hTokens > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func calculateTextToolCallSurcharge(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, summary *textQuotaSummary) decimal.Decimal {
|
||||||
|
dGroupRatio := decimal.NewFromFloat(summary.GroupRatio)
|
||||||
|
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||||
|
|
||||||
|
var surcharge decimal.Decimal
|
||||||
|
|
||||||
|
if relayInfo.ResponsesUsageInfo != nil {
|
||||||
|
if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 {
|
||||||
|
summary.WebSearchCallCount = webSearchTool.CallCount
|
||||||
|
summary.WebSearchPrice = operation_setting.GetToolPriceForModel("web_search_preview", summary.ModelName)
|
||||||
|
surcharge = surcharge.Add(decimal.NewFromFloat(summary.WebSearchPrice).
|
||||||
|
Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
|
||||||
|
Div(decimal.NewFromInt(1000)).
|
||||||
|
Mul(dGroupRatio).
|
||||||
|
Mul(dQuotaPerUnit))
|
||||||
|
}
|
||||||
|
} else if strings.HasSuffix(summary.ModelName, "search-preview") {
|
||||||
|
summary.WebSearchCallCount = 1
|
||||||
|
summary.WebSearchPrice = operation_setting.GetToolPriceForModel("web_search_preview", summary.ModelName)
|
||||||
|
surcharge = surcharge.Add(decimal.NewFromFloat(summary.WebSearchPrice).
|
||||||
|
Div(decimal.NewFromInt(1000)).
|
||||||
|
Mul(dGroupRatio).
|
||||||
|
Mul(dQuotaPerUnit))
|
||||||
|
}
|
||||||
|
|
||||||
|
summary.ClaudeWebSearchCallCount = ctx.GetInt("claude_web_search_requests")
|
||||||
|
if summary.ClaudeWebSearchCallCount > 0 {
|
||||||
|
summary.ClaudeWebSearchPrice = operation_setting.GetToolPrice("web_search")
|
||||||
|
surcharge = surcharge.Add(decimal.NewFromFloat(summary.ClaudeWebSearchPrice).
|
||||||
|
Div(decimal.NewFromInt(1000)).
|
||||||
|
Mul(dGroupRatio).
|
||||||
|
Mul(dQuotaPerUnit).
|
||||||
|
Mul(decimal.NewFromInt(int64(summary.ClaudeWebSearchCallCount))))
|
||||||
|
}
|
||||||
|
|
||||||
|
if relayInfo.ResponsesUsageInfo != nil {
|
||||||
|
if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists && fileSearchTool.CallCount > 0 {
|
||||||
|
summary.FileSearchCallCount = fileSearchTool.CallCount
|
||||||
|
summary.FileSearchPrice = operation_setting.GetToolPrice("file_search")
|
||||||
|
surcharge = surcharge.Add(decimal.NewFromFloat(summary.FileSearchPrice).
|
||||||
|
Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))).
|
||||||
|
Div(decimal.NewFromInt(1000)).
|
||||||
|
Mul(dGroupRatio).
|
||||||
|
Mul(dQuotaPerUnit))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.GetBool("image_generation_call") {
|
||||||
|
summary.ImageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size"))
|
||||||
|
surcharge = surcharge.Add(decimal.NewFromFloat(summary.ImageGenerationCallPrice).
|
||||||
|
Mul(dGroupRatio).
|
||||||
|
Mul(dQuotaPerUnit))
|
||||||
|
}
|
||||||
|
|
||||||
|
return surcharge
|
||||||
|
}
|
||||||
|
|
||||||
|
func composeTieredTextQuota(relayInfo *relaycommon.RelayInfo, summary textQuotaSummary, tieredQuota int, tieredResult *billingexpr.TieredResult) int {
|
||||||
|
if summary.ToolCallSurchargeQuota.IsZero() {
|
||||||
|
return tieredQuota
|
||||||
|
}
|
||||||
|
|
||||||
|
if tieredResult != nil {
|
||||||
|
if snap := relayInfo.TieredBillingSnapshot; snap != nil {
|
||||||
|
return int(decimal.NewFromFloat(tieredResult.ActualQuotaBeforeGroup).
|
||||||
|
Mul(decimal.NewFromFloat(snap.GroupRatio)).
|
||||||
|
Add(summary.ToolCallSurchargeQuota).
|
||||||
|
Round(0).
|
||||||
|
IntPart())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tieredQuota + int(summary.ToolCallSurchargeQuota.Round(0).IntPart())
|
||||||
|
}
|
||||||
|
|
||||||
func calculateTextQuotaSummary(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage) textQuotaSummary {
|
func calculateTextQuotaSummary(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage) textQuotaSummary {
|
||||||
summary := textQuotaSummary{
|
summary := textQuotaSummary{
|
||||||
ModelName: relayInfo.OriginModelName,
|
ModelName: relayInfo.OriginModelName,
|
||||||
@@ -148,52 +224,7 @@ func calculateTextQuotaSummary(ctx *gin.Context, relayInfo *relaycommon.RelayInf
|
|||||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||||
|
|
||||||
ratio := dModelRatio.Mul(dGroupRatio)
|
ratio := dModelRatio.Mul(dGroupRatio)
|
||||||
|
summary.ToolCallSurchargeQuota = calculateTextToolCallSurcharge(ctx, relayInfo, &summary)
|
||||||
var dWebSearchQuota decimal.Decimal
|
|
||||||
if relayInfo.ResponsesUsageInfo != nil {
|
|
||||||
if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 {
|
|
||||||
summary.WebSearchCallCount = webSearchTool.CallCount
|
|
||||||
summary.WebSearchPrice = operation_setting.GetToolPriceForModel("web_search_preview", summary.ModelName)
|
|
||||||
dWebSearchQuota = decimal.NewFromFloat(summary.WebSearchPrice).
|
|
||||||
Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
|
|
||||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
|
||||||
}
|
|
||||||
} else if strings.HasSuffix(summary.ModelName, "search-preview") {
|
|
||||||
searchContextSize := ctx.GetString("chat_completion_web_search_context_size")
|
|
||||||
if searchContextSize == "" {
|
|
||||||
searchContextSize = "medium"
|
|
||||||
}
|
|
||||||
summary.WebSearchCallCount = 1
|
|
||||||
summary.WebSearchPrice = operation_setting.GetToolPriceForModel("web_search_preview", summary.ModelName)
|
|
||||||
dWebSearchQuota = decimal.NewFromFloat(summary.WebSearchPrice).
|
|
||||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
|
||||||
}
|
|
||||||
|
|
||||||
var dClaudeWebSearchQuota decimal.Decimal
|
|
||||||
summary.ClaudeWebSearchCallCount = ctx.GetInt("claude_web_search_requests")
|
|
||||||
if summary.ClaudeWebSearchCallCount > 0 {
|
|
||||||
summary.ClaudeWebSearchPrice = operation_setting.GetToolPrice("web_search")
|
|
||||||
dClaudeWebSearchQuota = decimal.NewFromFloat(summary.ClaudeWebSearchPrice).
|
|
||||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit).
|
|
||||||
Mul(decimal.NewFromInt(int64(summary.ClaudeWebSearchCallCount)))
|
|
||||||
}
|
|
||||||
|
|
||||||
var dFileSearchQuota decimal.Decimal
|
|
||||||
if relayInfo.ResponsesUsageInfo != nil {
|
|
||||||
if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists && fileSearchTool.CallCount > 0 {
|
|
||||||
summary.FileSearchCallCount = fileSearchTool.CallCount
|
|
||||||
summary.FileSearchPrice = operation_setting.GetToolPrice("file_search")
|
|
||||||
dFileSearchQuota = decimal.NewFromFloat(summary.FileSearchPrice).
|
|
||||||
Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))).
|
|
||||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var dImageGenerationCallQuota decimal.Decimal
|
|
||||||
if ctx.GetBool("image_generation_call") {
|
|
||||||
summary.ImageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size"))
|
|
||||||
dImageGenerationCallQuota = decimal.NewFromFloat(summary.ImageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
|
||||||
}
|
|
||||||
|
|
||||||
var audioInputQuota decimal.Decimal
|
var audioInputQuota decimal.Decimal
|
||||||
if !relayInfo.PriceData.UsePrice {
|
if !relayInfo.PriceData.UsePrice {
|
||||||
@@ -242,11 +273,8 @@ func calculateTextQuotaSummary(ctx *gin.Context, relayInfo *relaycommon.RelayInf
|
|||||||
promptQuota := baseTokens.Add(cachedTokensWithRatio).Add(imageTokensWithRatio).Add(cachedCreationTokensWithRatio)
|
promptQuota := baseTokens.Add(cachedTokensWithRatio).Add(imageTokensWithRatio).Add(cachedCreationTokensWithRatio)
|
||||||
completionQuota := dCompletionTokens.Mul(dCompletionRatio)
|
completionQuota := dCompletionTokens.Mul(dCompletionRatio)
|
||||||
quotaCalculateDecimal := promptQuota.Add(completionQuota).Mul(ratio)
|
quotaCalculateDecimal := promptQuota.Add(completionQuota).Mul(ratio)
|
||||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
|
quotaCalculateDecimal = quotaCalculateDecimal.Add(summary.ToolCallSurchargeQuota)
|
||||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dClaudeWebSearchQuota)
|
|
||||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
|
|
||||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
|
quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
|
||||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)
|
|
||||||
|
|
||||||
if len(relayInfo.PriceData.OtherRatios) > 0 {
|
if len(relayInfo.PriceData.OtherRatios) > 0 {
|
||||||
for _, otherRatio := range relayInfo.PriceData.OtherRatios {
|
for _, otherRatio := range relayInfo.PriceData.OtherRatios {
|
||||||
@@ -260,11 +288,8 @@ func calculateTextQuotaSummary(ctx *gin.Context, relayInfo *relaycommon.RelayInf
|
|||||||
summary.Quota = int(quotaCalculateDecimal.Round(0).IntPart())
|
summary.Quota = int(quotaCalculateDecimal.Round(0).IntPart())
|
||||||
} else {
|
} else {
|
||||||
quotaCalculateDecimal := dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
|
quotaCalculateDecimal := dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
|
||||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
|
quotaCalculateDecimal = quotaCalculateDecimal.Add(summary.ToolCallSurchargeQuota)
|
||||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dClaudeWebSearchQuota)
|
|
||||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
|
|
||||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
|
quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
|
||||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)
|
|
||||||
if len(relayInfo.PriceData.OtherRatios) > 0 {
|
if len(relayInfo.PriceData.OtherRatios) > 0 {
|
||||||
for _, otherRatio := range relayInfo.PriceData.OtherRatios {
|
for _, otherRatio := range relayInfo.PriceData.OtherRatios {
|
||||||
quotaCalculateDecimal = quotaCalculateDecimal.Mul(decimal.NewFromFloat(otherRatio))
|
quotaCalculateDecimal = quotaCalculateDecimal.Mul(decimal.NewFromFloat(otherRatio))
|
||||||
@@ -313,7 +338,7 @@ func PostTextConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, us
|
|||||||
tieredOk, tieredQuota, tieredRes := TryTieredSettle(relayInfo, BuildTieredTokenParams(usage, summary.IsClaudeUsageSemantic, tieredUsedVars))
|
tieredOk, tieredQuota, tieredRes := TryTieredSettle(relayInfo, BuildTieredTokenParams(usage, summary.IsClaudeUsageSemantic, tieredUsedVars))
|
||||||
if tieredOk {
|
if tieredOk {
|
||||||
tieredResult = tieredRes
|
tieredResult = tieredRes
|
||||||
summary.Quota = tieredQuota
|
summary.Quota = composeTieredTextQuota(relayInfo, summary, tieredQuota, tieredRes)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/QuantumNous/new-api/constant"
|
"github.com/QuantumNous/new-api/constant"
|
||||||
"github.com/QuantumNous/new-api/dto"
|
"github.com/QuantumNous/new-api/dto"
|
||||||
|
"github.com/QuantumNous/new-api/pkg/billingexpr"
|
||||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
"github.com/QuantumNous/new-api/types"
|
"github.com/QuantumNous/new-api/types"
|
||||||
|
|
||||||
@@ -316,3 +317,125 @@ func TestCalculateTextQuotaSummaryKeepsPrePRClaudeOpenRouterBilling(t *testing.T
|
|||||||
require.Equal(t, 172, summary.PromptTokens)
|
require.Equal(t, 172, summary.PromptTokens)
|
||||||
require.Equal(t, 798, summary.Quota)
|
require.Equal(t, 798, summary.Quota)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestComposeTieredTextQuotaKeepsToolCallSurcharges(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(w)
|
||||||
|
ctx.Set("image_generation_call", true)
|
||||||
|
ctx.Set("image_generation_call_quality", "low")
|
||||||
|
ctx.Set("image_generation_call_size", "1024x1024")
|
||||||
|
|
||||||
|
relayInfo := &relaycommon.RelayInfo{
|
||||||
|
OriginModelName: "o1",
|
||||||
|
PriceData: types.PriceData{
|
||||||
|
ModelRatio: 1,
|
||||||
|
CompletionRatio: 1,
|
||||||
|
GroupRatioInfo: types.GroupRatioInfo{GroupRatio: 1},
|
||||||
|
},
|
||||||
|
ResponsesUsageInfo: &relaycommon.ResponsesUsageInfo{
|
||||||
|
BuiltInTools: map[string]*relaycommon.BuildInToolInfo{
|
||||||
|
dto.BuildInToolWebSearchPreview: &relaycommon.BuildInToolInfo{
|
||||||
|
CallCount: 1,
|
||||||
|
},
|
||||||
|
dto.BuildInToolFileSearch: &relaycommon.BuildInToolInfo{
|
||||||
|
CallCount: 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
TieredBillingSnapshot: &billingexpr.BillingSnapshot{
|
||||||
|
BillingMode: "tiered_expr",
|
||||||
|
GroupRatio: 1,
|
||||||
|
EstimatedQuotaBeforeGroup: 1000,
|
||||||
|
},
|
||||||
|
StartTime: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
usage := &dto.Usage{
|
||||||
|
PromptTokens: 100,
|
||||||
|
CompletionTokens: 50,
|
||||||
|
TotalTokens: 150,
|
||||||
|
}
|
||||||
|
|
||||||
|
summary := calculateTextQuotaSummary(ctx, relayInfo, usage)
|
||||||
|
quota := composeTieredTextQuota(relayInfo, summary, 1000, &billingexpr.TieredResult{
|
||||||
|
ActualQuotaBeforeGroup: 1000,
|
||||||
|
ActualQuotaAfterGroup: 1000,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Equal(t, int64(13000), summary.ToolCallSurchargeQuota.Round(0).IntPart())
|
||||||
|
require.Equal(t, 14000, quota)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComposeTieredTextQuotaFallbackKeepsToolCallSurcharges(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(w)
|
||||||
|
ctx.Set("claude_web_search_requests", 2)
|
||||||
|
|
||||||
|
relayInfo := &relaycommon.RelayInfo{
|
||||||
|
OriginModelName: "claude-3-7-sonnet",
|
||||||
|
PriceData: types.PriceData{
|
||||||
|
ModelRatio: 1,
|
||||||
|
CompletionRatio: 1,
|
||||||
|
GroupRatioInfo: types.GroupRatioInfo{GroupRatio: 1.25},
|
||||||
|
},
|
||||||
|
TieredBillingSnapshot: &billingexpr.BillingSnapshot{
|
||||||
|
BillingMode: "tiered_expr",
|
||||||
|
GroupRatio: 1.25,
|
||||||
|
EstimatedQuotaBeforeGroup: 1000,
|
||||||
|
},
|
||||||
|
StartTime: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
usage := &dto.Usage{
|
||||||
|
PromptTokens: 100,
|
||||||
|
CompletionTokens: 50,
|
||||||
|
TotalTokens: 150,
|
||||||
|
}
|
||||||
|
|
||||||
|
summary := calculateTextQuotaSummary(ctx, relayInfo, usage)
|
||||||
|
quota := composeTieredTextQuota(relayInfo, summary, 1250, nil)
|
||||||
|
|
||||||
|
require.Equal(t, int64(12500), summary.ToolCallSurchargeQuota.Round(0).IntPart())
|
||||||
|
require.Equal(t, 13750, quota)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComposeTieredTextQuotaErrorFallbackUsesPreConsumedQuota(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(w)
|
||||||
|
ctx.Set("claude_web_search_requests", 2)
|
||||||
|
|
||||||
|
relayInfo := &relaycommon.RelayInfo{
|
||||||
|
OriginModelName: "claude-3-7-sonnet",
|
||||||
|
PriceData: types.PriceData{
|
||||||
|
ModelRatio: 1,
|
||||||
|
CompletionRatio: 1,
|
||||||
|
GroupRatioInfo: types.GroupRatioInfo{GroupRatio: 1.25},
|
||||||
|
},
|
||||||
|
TieredBillingSnapshot: &billingexpr.BillingSnapshot{
|
||||||
|
BillingMode: "tiered_expr",
|
||||||
|
GroupRatio: 1.25,
|
||||||
|
EstimatedQuotaBeforeGroup: 1000,
|
||||||
|
},
|
||||||
|
StartTime: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
usage := &dto.Usage{
|
||||||
|
PromptTokens: 100,
|
||||||
|
CompletionTokens: 50,
|
||||||
|
TotalTokens: 150,
|
||||||
|
}
|
||||||
|
|
||||||
|
summary := calculateTextQuotaSummary(ctx, relayInfo, usage)
|
||||||
|
|
||||||
|
// tieredResult=nil simulates a settlement error where TryTieredSettle
|
||||||
|
// falls back to FinalPreConsumedQuota (2000), which differs from
|
||||||
|
// EstimatedQuotaBeforeGroup * GroupRatio (1250).
|
||||||
|
preConsumedFallback := 2000
|
||||||
|
quota := composeTieredTextQuota(relayInfo, summary, preConsumedFallback, nil)
|
||||||
|
|
||||||
|
require.Equal(t, int64(12500), summary.ToolCallSurchargeQuota.Round(0).IntPart())
|
||||||
|
require.Equal(t, 14500, quota)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user