diff --git a/.gitignore b/.gitignore index 54fa8311..b5d65bcf 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,6 @@ data/ .gomodcache/ .gocache-temp .gopath + +token_estimator_test.go +skills-lock.json \ No newline at end of file diff --git a/go.mod b/go.mod index 2f28f781..43808976 100644 --- a/go.mod +++ b/go.mod @@ -75,6 +75,7 @@ require ( github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dlclark/regexp2 v1.11.5 // indirect github.com/dustin/go-humanize v1.0.1 // indirect + github.com/expr-lang/expr v1.17.8 // indirect github.com/fxamacker/cbor/v2 v2.9.0 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gin-contrib/sse v0.1.0 // indirect diff --git a/go.sum b/go.sum index 74298929..eb2be8a4 100644 --- a/go.sum +++ b/go.sum @@ -68,6 +68,8 @@ github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZ github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/expr-lang/expr v1.17.8 h1:W1loDTT+0PQf5YteHSTpju2qfUfNoBt4yw9+wOEU9VM= +github.com/expr-lang/expr v1.17.8/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM= diff --git a/model/option.go b/model/option.go index 697e77df..b768addc 100644 --- a/model/option.go +++ b/model/option.go @@ -7,6 +7,7 @@ import ( "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/setting" + "github.com/QuantumNous/new-api/setting/billing_setting" "github.com/QuantumNous/new-api/setting/config" "github.com/QuantumNous/new-api/setting/operation_setting" "github.com/QuantumNous/new-api/setting/performance_setting" @@ -121,6 +122,8 @@ func InitOptionMap() { common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString() common.OptionMap["CompletionRatio"] = ratio_setting.CompletionRatio2JSONString() common.OptionMap["ImageRatio"] = ratio_setting.ImageRatio2JSONString() + common.OptionMap["ModelBillingMode"] = billing_setting.BillingMode2JSONString() + common.OptionMap["ModelBillingExpr"] = billing_setting.BillingExpr2JSONString() common.OptionMap["AudioRatio"] = ratio_setting.AudioRatio2JSONString() common.OptionMap["AudioCompletionRatio"] = ratio_setting.AudioCompletionRatio2JSONString() common.OptionMap["TopUpLink"] = common.TopUpLink @@ -436,6 +439,10 @@ func updateOptionMap(key string, value string) (err error) { err = ratio_setting.UpdateAudioRatioByJSONString(value) case "AudioCompletionRatio": err = ratio_setting.UpdateAudioCompletionRatioByJSONString(value) + case "ModelBillingMode": + err = billing_setting.UpdateBillingModeByJSONString(value) + case "ModelBillingExpr": + err = billing_setting.UpdateBillingExprByJSONString(value) case "TopUpLink": common.TopUpLink = value //case "ChatLink": diff --git a/pkg/billingexpr/billingexpr_test.go b/pkg/billingexpr/billingexpr_test.go new file mode 100644 index 00000000..34ec0fc0 --- /dev/null +++ b/pkg/billingexpr/billingexpr_test.go @@ -0,0 +1,933 @@ +package billingexpr_test + +import ( + "math" + "math/rand" + "testing" + + "github.com/QuantumNous/new-api/pkg/billingexpr" +) + +// --------------------------------------------------------------------------- +// Claude-style: fixed tiers, input > 200K changes both input & output price +// --------------------------------------------------------------------------- + +const claudeExpr = `p <= 200000 ? tier("standard", p * 1.5 + c * 7.5) : tier("long_context", p * 3.0 + c * 11.25)` + +func TestClaude_StandardTier(t *testing.T) { + cost, trace, err := billingexpr.RunExpr(claudeExpr, billingexpr.TokenParams{P: 100000, C: 5000}) + if err != nil { + t.Fatal(err) + } + want := 100000*1.5 + 5000*7.5 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f", cost, want) + } + if trace.MatchedTier != "standard" { + t.Errorf("tier = %q, want %q", trace.MatchedTier, "standard") + } +} + +func TestClaude_LongContextTier(t *testing.T) { + cost, trace, err := billingexpr.RunExpr(claudeExpr, billingexpr.TokenParams{P: 300000, C: 10000}) + if err != nil { + t.Fatal(err) + } + want := 300000*3.0 + 10000*11.25 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f", cost, want) + } + if trace.MatchedTier != "long_context" { + t.Errorf("tier = %q, want %q", trace.MatchedTier, "long_context") + } +} + +func TestClaude_BoundaryExact(t *testing.T) { + cost, trace, err := billingexpr.RunExpr(claudeExpr, billingexpr.TokenParams{P: 200000, C: 1000}) + if err != nil { + t.Fatal(err) + } + want := 200000*1.5 + 1000*7.5 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f", cost, want) + } + if trace.MatchedTier != "standard" { + t.Errorf("tier = %q, want %q", trace.MatchedTier, "standard") + } +} + +// --------------------------------------------------------------------------- +// GLM-style: multi-condition tiers with both input and output dimensions +// --------------------------------------------------------------------------- + +const glmExpr = ` +( + p < 32000 && c < 200 ? tier("tier1_short", (p)*2 + c*8) : + p < 32000 && c >= 200 ? tier("tier2_long_output", (p)*3 + c*14) : + tier("tier3_long_input", (p)*4 + c*16) +) / 1000000 +` + +func TestGLM_Tier1(t *testing.T) { + cost, trace, err := billingexpr.RunExpr(glmExpr, billingexpr.TokenParams{P: 15000, C: 100}) + if err != nil { + t.Fatal(err) + } + want := (15000.0*2 + 100.0*8) / 1000000 + if math.Abs(cost-want) > 1e-10 { + t.Errorf("cost = %f, want %f", cost, want) + } + if trace.MatchedTier != "tier1_short" { + t.Errorf("tier = %q, want %q", trace.MatchedTier, "tier1_short") + } +} + +func TestGLM_Tier2(t *testing.T) { + cost, trace, err := billingexpr.RunExpr(glmExpr, billingexpr.TokenParams{P: 15000, C: 500}) + if err != nil { + t.Fatal(err) + } + want := (15000.0*3 + 500.0*14) / 1000000 + if math.Abs(cost-want) > 1e-10 { + t.Errorf("cost = %f, want %f", cost, want) + } + if trace.MatchedTier != "tier2_long_output" { + t.Errorf("tier = %q, want %q", trace.MatchedTier, "tier2_long_output") + } +} + +func TestGLM_Tier3(t *testing.T) { + cost, trace, err := billingexpr.RunExpr(glmExpr, billingexpr.TokenParams{P: 50000, C: 100}) + if err != nil { + t.Fatal(err) + } + want := (50000.0*4 + 100.0*16) / 1000000 + if math.Abs(cost-want) > 1e-10 { + t.Errorf("cost = %f, want %f", cost, want) + } + if trace.MatchedTier != "tier3_long_input" { + t.Errorf("tier = %q, want %q", trace.MatchedTier, "tier3_long_input") + } +} + +// --------------------------------------------------------------------------- +// Simple flat-rate (no tier() call) +// --------------------------------------------------------------------------- + +func TestSimpleExpr_NoTier(t *testing.T) { + cost, trace, err := billingexpr.RunExpr("p * 0.5 + c * 1.0", billingexpr.TokenParams{P: 1000, C: 500}) + if err != nil { + t.Fatal(err) + } + want := 1000*0.5 + 500*1.0 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f", cost, want) + } + if trace.MatchedTier != "" { + t.Errorf("tier should be empty, got %q", trace.MatchedTier) + } +} + +// --------------------------------------------------------------------------- +// Math helper functions +// --------------------------------------------------------------------------- + +func TestMathHelpers(t *testing.T) { + cost, _, err := billingexpr.RunExpr("max(p, c) * 0.5 + min(p, c) * 0.1", billingexpr.TokenParams{P: 300, C: 500}) + if err != nil { + t.Fatal(err) + } + want := 500*0.5 + 300*0.1 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f", cost, want) + } +} + +func TestRequestProbeHelpers(t *testing.T) { + cost, _, err := billingexpr.RunExprWithRequest( + `prompt_tokens * 0.5 + completion_tokens * 1.0 * (param("service_tier") == "fast" ? 2 : 1)`, + billingexpr.TokenParams{P: 1000, C: 500}, + billingexpr.RequestInput{ + Body: []byte(`{"service_tier":"fast"}`), + }, + ) + if err != nil { + t.Fatal(err) + } + want := 1000*0.5 + 500*1.0*2 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f", cost, want) + } +} + +func TestHeaderProbeHelper(t *testing.T) { + cost, _, err := billingexpr.RunExprWithRequest( + `p * 0.5 + c * 1.0 * (has(header("anthropic-beta"), "fast-mode") ? 2 : 1)`, + billingexpr.TokenParams{P: 1000, C: 500}, + billingexpr.RequestInput{ + Headers: map[string]string{ + "Anthropic-Beta": "fast-mode-2026-02-01", + }, + }, + ) + if err != nil { + t.Fatal(err) + } + want := 1000*0.5 + 500*1.0*2 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f", cost, want) + } +} + +func TestParamProbeNestedBool(t *testing.T) { + cost, _, err := billingexpr.RunExprWithRequest( + `p * (param("stream_options.fast_mode") == true ? 1.5 : 1.0)`, + billingexpr.TokenParams{P: 100}, + billingexpr.RequestInput{ + Body: []byte(`{"stream_options":{"fast_mode":true}}`), + }, + ) + if err != nil { + t.Fatal(err) + } + want := 150.0 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f", cost, want) + } +} + +func TestParamProbeArrayLength(t *testing.T) { + cost, _, err := billingexpr.RunExprWithRequest( + `p * (param("messages.#") > 20 ? 1.2 : 1.0)`, + billingexpr.TokenParams{P: 100}, + billingexpr.RequestInput{ + Body: []byte(`{"messages":[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21]}`), + }, + ) + if err != nil { + t.Fatal(err) + } + want := 120.0 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f", cost, want) + } +} + +func TestRequestProbeMissingFieldReturnsNil(t *testing.T) { + cost, _, err := billingexpr.RunExprWithRequest( + `param("missing.value") == nil ? 2 : 1`, + billingexpr.TokenParams{}, + billingexpr.RequestInput{ + Body: []byte(`{"service_tier":"standard"}`), + }, + ) + if err != nil { + t.Fatal(err) + } + if cost != 2 { + t.Errorf("cost = %f, want 2", cost) + } +} + +func TestRequestProbeMultipleRulesMultiply(t *testing.T) { + cost, _, err := billingexpr.RunExprWithRequest( + `(param("service_tier") == "fast" ? 2 : 1) * (has(header("anthropic-beta"), "fast-mode-2026-02-01") ? 2.5 : 1)`, + billingexpr.TokenParams{}, + billingexpr.RequestInput{ + Headers: map[string]string{ + "Anthropic-Beta": "fast-mode-2026-02-01", + }, + Body: []byte(`{"service_tier":"fast"}`), + }, + ) + if err != nil { + t.Fatal(err) + } + if math.Abs(cost-5) > 1e-6 { + t.Errorf("cost = %f, want 5", cost) + } +} + +func TestCeilFloor(t *testing.T) { + cost, _, err := billingexpr.RunExpr("ceil(p / 1000) * 0.5", billingexpr.TokenParams{P: 1500}) + if err != nil { + t.Fatal(err) + } + want := math.Ceil(1500.0/1000) * 0.5 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f", cost, want) + } +} + +// --------------------------------------------------------------------------- +// Zero tokens +// --------------------------------------------------------------------------- + +func TestZeroTokens(t *testing.T) { + cost, _, err := billingexpr.RunExpr(claudeExpr, billingexpr.TokenParams{}) + if err != nil { + t.Fatal(err) + } + if cost != 0 { + t.Errorf("cost should be 0 for zero tokens, got %f", cost) + } +} + +// --------------------------------------------------------------------------- +// Rounding +// --------------------------------------------------------------------------- + +func TestQuotaRound(t *testing.T) { + tests := []struct { + in float64 + want int + }{ + {0, 0}, + {0.4, 0}, + {0.5, 1}, + {0.6, 1}, + {1.5, 2}, + {-0.5, -1}, + {-0.6, -1}, + {999.4999, 999}, + {999.5, 1000}, + {1e9 + 0.5, 1e9 + 1}, + } + for _, tt := range tests { + got := billingexpr.QuotaRound(tt.in) + if got != tt.want { + t.Errorf("QuotaRound(%f) = %d, want %d", tt.in, got, tt.want) + } + } +} + +// --------------------------------------------------------------------------- +// Settlement +// --------------------------------------------------------------------------- + +func TestComputeTieredQuota_Basic(t *testing.T) { + snap := &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: claudeExpr, + ExprHash: billingexpr.ExprHashString(claudeExpr), + GroupRatio: 1.0, + EstimatedPromptTokens: 100000, + EstimatedCompletionTokens: 5000, + EstimatedQuotaBeforeGroup: 100000*1.5 + 5000*7.5, + EstimatedQuotaAfterGroup: billingexpr.QuotaRound(100000*1.5 + 5000*7.5), + EstimatedTier: "standard", + } + + result, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{P: 300000, C: 10000}) + if err != nil { + t.Fatal(err) + } + + wantBefore := 300000*3.0 + 10000*11.25 + if math.Abs(result.ActualQuotaBeforeGroup-wantBefore) > 1e-6 { + t.Errorf("before group: got %f, want %f", result.ActualQuotaBeforeGroup, wantBefore) + } + if result.MatchedTier != "long_context" { + t.Errorf("tier = %q, want %q", result.MatchedTier, "long_context") + } + if !result.CrossedTier { + t.Error("expected crossed_tier=true (estimated standard, actual long_context)") + } +} + +func TestComputeTieredQuota_SameTier(t *testing.T) { + snap := &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: claudeExpr, + ExprHash: billingexpr.ExprHashString(claudeExpr), + GroupRatio: 1.5, + EstimatedPromptTokens: 50000, + EstimatedCompletionTokens: 1000, + EstimatedQuotaBeforeGroup: 50000*1.5 + 1000*7.5, + EstimatedQuotaAfterGroup: billingexpr.QuotaRound((50000*1.5 + 1000*7.5) * 1.5), + EstimatedTier: "standard", + } + + result, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{P: 80000, C: 2000}) + if err != nil { + t.Fatal(err) + } + + wantBefore := 80000*1.5 + 2000*7.5 + wantAfter := billingexpr.QuotaRound(wantBefore * 1.5) + if result.ActualQuotaAfterGroup != wantAfter { + t.Errorf("after group: got %d, want %d", result.ActualQuotaAfterGroup, wantAfter) + } + if result.CrossedTier { + t.Error("expected crossed_tier=false (both standard)") + } +} + +// --------------------------------------------------------------------------- +// Compile errors +// --------------------------------------------------------------------------- + +func TestCompileError(t *testing.T) { + _, _, err := billingexpr.RunExpr("invalid +-+ syntax", billingexpr.TokenParams{}) + if err == nil { + t.Error("expected compile error") + } +} + +// --------------------------------------------------------------------------- +// Compile Cache +// --------------------------------------------------------------------------- + +func TestCompileCache_SameResult(t *testing.T) { + r1, _, err := billingexpr.RunExpr("p * 0.5", billingexpr.TokenParams{P: 100}) + if err != nil { + t.Fatal(err) + } + r2, _, err := billingexpr.RunExpr("p * 0.5", billingexpr.TokenParams{P: 100}) + if err != nil { + t.Fatal(err) + } + if r1 != r2 { + t.Errorf("cached and uncached results differ: %f != %f", r1, r2) + } +} + +func TestInvalidateCache(t *testing.T) { + billingexpr.InvalidateCache() + r1, _, _ := billingexpr.RunExpr("p * 0.5", billingexpr.TokenParams{P: 100}) + billingexpr.InvalidateCache() + r2, _, _ := billingexpr.RunExpr("p * 0.5", billingexpr.TokenParams{P: 100}) + if r1 != r2 { + t.Errorf("post-invalidate results differ: %f != %f", r1, r2) + } +} + +// --------------------------------------------------------------------------- +// Hash +// --------------------------------------------------------------------------- + +func TestExprHashString_Deterministic(t *testing.T) { + h1 := billingexpr.ExprHashString("p * 0.5") + h2 := billingexpr.ExprHashString("p * 0.5") + if h1 != h2 { + t.Error("hash should be deterministic") + } + h3 := billingexpr.ExprHashString("p * 0.6") + if h1 == h3 { + t.Error("different expressions should have different hashes") + } +} + +// --------------------------------------------------------------------------- +// Cache variables: present +// --------------------------------------------------------------------------- + +const claudeWithCacheExpr = `p <= 200000 ? tier("standard", p * 1.5 + c * 7.5 + cr * 0.15 + cc * 1.875) : tier("long_context", p * 3.0 + c * 11.25 + cr * 0.3 + cc * 3.75)` + +func TestCachePresent_StandardTier(t *testing.T) { + params := billingexpr.TokenParams{P: 100000, C: 5000, CR: 50000, CC: 10000} + cost, trace, err := billingexpr.RunExpr(claudeWithCacheExpr, params) + if err != nil { + t.Fatal(err) + } + want := 100000*1.5 + 5000*7.5 + 50000*0.15 + 10000*1.875 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f", cost, want) + } + if trace.MatchedTier != "standard" { + t.Errorf("tier = %q, want %q", trace.MatchedTier, "standard") + } +} + +func TestCachePresent_LongContextTier(t *testing.T) { + params := billingexpr.TokenParams{P: 300000, C: 10000, CR: 100000, CC: 20000} + cost, trace, err := billingexpr.RunExpr(claudeWithCacheExpr, params) + if err != nil { + t.Fatal(err) + } + want := 300000*3.0 + 10000*11.25 + 100000*0.3 + 20000*3.75 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f", cost, want) + } + if trace.MatchedTier != "long_context" { + t.Errorf("tier = %q, want %q", trace.MatchedTier, "long_context") + } +} + +// --------------------------------------------------------------------------- +// Cache variables: absent (all zero) — same expression still works +// --------------------------------------------------------------------------- + +func TestCacheAbsent_ZeroCacheTokens(t *testing.T) { + params := billingexpr.TokenParams{P: 100000, C: 5000} + cost, trace, err := billingexpr.RunExpr(claudeWithCacheExpr, params) + if err != nil { + t.Fatal(err) + } + want := 100000*1.5 + 5000*7.5 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f (cache terms should be 0)", cost, want) + } + if trace.MatchedTier != "standard" { + t.Errorf("tier = %q, want %q", trace.MatchedTier, "standard") + } +} + +// --------------------------------------------------------------------------- +// Mixed cache fields: cc and cc1h non-zero +// --------------------------------------------------------------------------- + +const claudeCacheSplitExpr = `tier("default", p * 1.5 + c * 7.5 + cr * 0.15 + cc * 2.0 + cc1h * 3.0)` + +func TestMixedCacheFields(t *testing.T) { + params := billingexpr.TokenParams{P: 100000, C: 5000, CR: 10000, CC: 5000, CC1h: 2000} + cost, _, err := billingexpr.RunExpr(claudeCacheSplitExpr, params) + if err != nil { + t.Fatal(err) + } + want := 100000*1.5 + 5000*7.5 + 10000*0.15 + 5000*2.0 + 2000*3.0 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f", cost, want) + } +} + +func TestMixedCacheFields_AllCacheZero(t *testing.T) { + params := billingexpr.TokenParams{P: 100000, C: 5000} + cost, _, err := billingexpr.RunExpr(claudeCacheSplitExpr, params) + if err != nil { + t.Fatal(err) + } + want := 100000*1.5 + 5000*7.5 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f (all cache zero)", cost, want) + } +} + +// --------------------------------------------------------------------------- +// Backward compatibility: p+c only expressions still work with TokenParams +// --------------------------------------------------------------------------- + +func TestBackwardCompat_OldExprWithTokenParams(t *testing.T) { + params := billingexpr.TokenParams{P: 100000, C: 5000, CR: 99999, CC: 88888} + cost, trace, err := billingexpr.RunExpr(claudeExpr, params) + if err != nil { + t.Fatal(err) + } + want := 100000*1.5 + 5000*7.5 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f (old expr ignores cache fields)", cost, want) + } + if trace.MatchedTier != "standard" { + t.Errorf("tier = %q, want %q", trace.MatchedTier, "standard") + } +} + +// --------------------------------------------------------------------------- +// Settlement with cache tokens +// --------------------------------------------------------------------------- + +func TestComputeTieredQuota_WithCache(t *testing.T) { + snap := &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: claudeWithCacheExpr, + ExprHash: billingexpr.ExprHashString(claudeWithCacheExpr), + GroupRatio: 1.0, + EstimatedPromptTokens: 100000, + EstimatedCompletionTokens: 5000, + EstimatedQuotaBeforeGroup: 100000*1.5 + 5000*7.5, + EstimatedQuotaAfterGroup: billingexpr.QuotaRound(100000*1.5 + 5000*7.5), + EstimatedTier: "standard", + } + + params := billingexpr.TokenParams{P: 100000, C: 5000, CR: 50000, CC: 10000} + result, err := billingexpr.ComputeTieredQuota(snap, params) + if err != nil { + t.Fatal(err) + } + + wantBefore := 100000*1.5 + 5000*7.5 + 50000*0.15 + 10000*1.875 + if math.Abs(result.ActualQuotaBeforeGroup-wantBefore) > 1e-6 { + t.Errorf("before group: got %f, want %f", result.ActualQuotaBeforeGroup, wantBefore) + } + if result.MatchedTier != "standard" { + t.Errorf("tier = %q, want %q", result.MatchedTier, "standard") + } + if result.CrossedTier { + t.Error("expected crossed_tier=false (same tier)") + } +} + +func TestComputeTieredQuota_WithCacheCrossTier(t *testing.T) { + snap := &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: claudeWithCacheExpr, + ExprHash: billingexpr.ExprHashString(claudeWithCacheExpr), + GroupRatio: 2.0, + EstimatedPromptTokens: 100000, + EstimatedCompletionTokens: 5000, + EstimatedQuotaBeforeGroup: 100000*1.5 + 5000*7.5, + EstimatedQuotaAfterGroup: billingexpr.QuotaRound((100000*1.5 + 5000*7.5) * 2.0), + EstimatedTier: "standard", + } + + params := billingexpr.TokenParams{P: 300000, C: 10000, CR: 50000, CC: 10000} + result, err := billingexpr.ComputeTieredQuota(snap, params) + if err != nil { + t.Fatal(err) + } + + wantBefore := 300000*3.0 + 10000*11.25 + 50000*0.3 + 10000*3.75 + wantAfter := billingexpr.QuotaRound(wantBefore * 2.0) + if math.Abs(result.ActualQuotaBeforeGroup-wantBefore) > 1e-6 { + t.Errorf("before group: got %f, want %f", result.ActualQuotaBeforeGroup, wantBefore) + } + if result.ActualQuotaAfterGroup != wantAfter { + t.Errorf("after group: got %d, want %d", result.ActualQuotaAfterGroup, wantAfter) + } + if !result.CrossedTier { + t.Error("expected crossed_tier=true (estimated standard, actual long_context)") + } +} + +// --------------------------------------------------------------------------- +// Fuzz: random p/c/cache, verify non-negative result +// --------------------------------------------------------------------------- + +func TestFuzz_NonNegativeResults(t *testing.T) { + exprs := []string{ + claudeExpr, + claudeWithCacheExpr, + glmExpr, + "p * 0.5 + c * 1.0", + "max(p, c) * 0.1", + "p * 0.5 + cr * 0.1 + cc * 0.2", + } + + rng := rand.New(rand.NewSource(42)) + + for _, exprStr := range exprs { + for i := 0; i < 500; i++ { + params := billingexpr.TokenParams{ + P: math.Round(rng.Float64() * 1000000), + C: math.Round(rng.Float64() * 500000), + CR: math.Round(rng.Float64() * 200000), + CC: math.Round(rng.Float64() * 50000), + CC1h: math.Round(rng.Float64() * 10000), + } + + cost, _, err := billingexpr.RunExpr(exprStr, params) + if err != nil { + t.Fatalf("expr=%q params=%+v: %v", exprStr, params, err) + } + if cost < 0 { + t.Errorf("expr=%q params=%+v: negative cost %f", exprStr, params, cost) + } + } + } +} + +func TestFuzz_SettlementConsistency(t *testing.T) { + rng := rand.New(rand.NewSource(99)) + + for i := 0; i < 200; i++ { + estParams := billingexpr.TokenParams{ + P: math.Round(rng.Float64() * 500000), + C: math.Round(rng.Float64() * 100000), + CR: math.Round(rng.Float64() * 100000), + CC: math.Round(rng.Float64() * 30000), + } + actParams := billingexpr.TokenParams{ + P: math.Round(rng.Float64() * 500000), + C: math.Round(rng.Float64() * 100000), + CR: math.Round(rng.Float64() * 100000), + CC: math.Round(rng.Float64() * 30000), + } + groupRatio := 0.5 + rng.Float64()*2.0 + + estCost, estTrace, _ := billingexpr.RunExpr(claudeWithCacheExpr, estParams) + + snap := &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: claudeWithCacheExpr, + ExprHash: billingexpr.ExprHashString(claudeWithCacheExpr), + GroupRatio: groupRatio, + EstimatedPromptTokens: int(estParams.P), + EstimatedCompletionTokens: int(estParams.C), + EstimatedQuotaBeforeGroup: estCost, + EstimatedQuotaAfterGroup: billingexpr.QuotaRound(estCost * groupRatio), + EstimatedTier: estTrace.MatchedTier, + } + + result, err := billingexpr.ComputeTieredQuota(snap, actParams) + if err != nil { + t.Fatalf("iter %d: %v", i, err) + } + + directCost, _, _ := billingexpr.RunExpr(claudeWithCacheExpr, actParams) + directQuota := billingexpr.QuotaRound(directCost * groupRatio) + + if result.ActualQuotaAfterGroup != directQuota { + t.Errorf("iter %d: settlement %d != direct %d", i, result.ActualQuotaAfterGroup, directQuota) + } + } +} + +// --------------------------------------------------------------------------- +// Settlement-level tests for ComputeTieredQuota +// --------------------------------------------------------------------------- + +func TestComputeTieredQuota_BasicSettlement(t *testing.T) { + exprStr := `tier("default", p + c)` + snap := &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: exprStr, + ExprHash: billingexpr.ExprHashString(exprStr), + GroupRatio: 1.0, + } + + result, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{P: 3000, C: 2000}) + if err != nil { + t.Fatal(err) + } + if math.Abs(result.ActualQuotaBeforeGroup-5000) > 1e-6 { + t.Errorf("before group = %f, want 5000", result.ActualQuotaBeforeGroup) + } + if result.ActualQuotaAfterGroup != 5000 { + t.Errorf("after group = %d, want 5000", result.ActualQuotaAfterGroup) + } + if result.MatchedTier != "default" { + t.Errorf("tier = %q, want default", result.MatchedTier) + } +} + +func TestComputeTieredQuota_WithGroupRatio(t *testing.T) { + exprStr := `tier("default", p + c)` + snap := &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: exprStr, + ExprHash: billingexpr.ExprHashString(exprStr), + GroupRatio: 2.0, + } + + result, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{P: 1000, C: 500}) + if err != nil { + t.Fatal(err) + } + // cost = 1500, after group = round(1500 * 2.0) = 3000 + if result.ActualQuotaAfterGroup != 3000 { + t.Errorf("after group = %d, want 3000", result.ActualQuotaAfterGroup) + } +} + +func TestComputeTieredQuota_ZeroTokens(t *testing.T) { + exprStr := `tier("default", p * 2 + c * 10)` + snap := &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: exprStr, + ExprHash: billingexpr.ExprHashString(exprStr), + GroupRatio: 1.0, + } + + result, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{}) + if err != nil { + t.Fatal(err) + } + if result.ActualQuotaAfterGroup != 0 { + t.Errorf("after group = %d, want 0", result.ActualQuotaAfterGroup) + } +} + +func TestComputeTieredQuota_RoundingEdge(t *testing.T) { + exprStr := `tier("default", p * 0.5)` // 3 * 0.5 = 1.5 -> round to 2 + snap := &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: exprStr, + ExprHash: billingexpr.ExprHashString(exprStr), + GroupRatio: 1.0, + } + + result, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{P: 3}) + if err != nil { + t.Fatal(err) + } + // 3 * 0.5 = 1.5, round(1.5) = 2 + if result.ActualQuotaAfterGroup != 2 { + t.Errorf("after group = %d, want 2 (round 1.5 up)", result.ActualQuotaAfterGroup) + } +} + +func TestComputeTieredQuota_RoundingEdgeDown(t *testing.T) { + exprStr := `tier("default", p * 0.4)` // 3 * 0.4 = 1.2 -> round to 1 + snap := &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: exprStr, + ExprHash: billingexpr.ExprHashString(exprStr), + GroupRatio: 1.0, + } + + result, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{P: 3}) + if err != nil { + t.Fatal(err) + } + // 3 * 0.4 = 1.2, round(1.2) = 1 + if result.ActualQuotaAfterGroup != 1 { + t.Errorf("after group = %d, want 1 (round 1.2 down)", result.ActualQuotaAfterGroup) + } +} + +func TestComputeTieredQuotaWithRequest_ProbeAffectsQuota(t *testing.T) { + exprStr := `param("fast") == true ? tier("fast", p * 4) : tier("normal", p * 2)` + snap := &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: exprStr, + ExprHash: billingexpr.ExprHashString(exprStr), + GroupRatio: 1.0, + EstimatedTier: "normal", + } + + // Without request: normal tier + r1, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{P: 1000}) + if err != nil { + t.Fatal(err) + } + if r1.ActualQuotaAfterGroup != 2000 { + t.Errorf("normal = %d, want 2000", r1.ActualQuotaAfterGroup) + } + + // With request: fast tier + r2, err := billingexpr.ComputeTieredQuotaWithRequest(snap, billingexpr.TokenParams{P: 1000}, billingexpr.RequestInput{ + Body: []byte(`{"fast":true}`), + }) + if err != nil { + t.Fatal(err) + } + if r2.ActualQuotaAfterGroup != 4000 { + t.Errorf("fast = %d, want 4000", r2.ActualQuotaAfterGroup) + } + if !r2.CrossedTier { + t.Error("expected CrossedTier = true when probe changes tier") + } +} + +func TestComputeTieredQuota_BoundaryTierCrossing(t *testing.T) { + exprStr := `p <= 100000 ? tier("small", p * 1) : tier("large", p * 2)` + snap := &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: exprStr, + ExprHash: billingexpr.ExprHashString(exprStr), + GroupRatio: 1.0, + EstimatedTier: "small", + } + + // At boundary + r1, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{P: 100000}) + if err != nil { + t.Fatal(err) + } + if r1.MatchedTier != "small" { + t.Errorf("at boundary: tier = %s, want small", r1.MatchedTier) + } + if r1.ActualQuotaAfterGroup != 100000 { + t.Errorf("at boundary: quota = %d, want 100000", r1.ActualQuotaAfterGroup) + } + + // Past boundary + r2, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{P: 100001}) + if err != nil { + t.Fatal(err) + } + if r2.MatchedTier != "large" { + t.Errorf("past boundary: tier = %s, want large", r2.MatchedTier) + } + if r2.ActualQuotaAfterGroup != 200002 { + t.Errorf("past boundary: quota = %d, want 200002", r2.ActualQuotaAfterGroup) + } + if !r2.CrossedTier { + t.Error("expected CrossedTier = true") + } +} + +// --------------------------------------------------------------------------- +// Time function tests +// --------------------------------------------------------------------------- + +func TestTimeFunctions_ValidTimezone(t *testing.T) { + exprStr := `tier("default", p) * (hour("UTC") >= 0 ? 1 : 1)` + cost, _, err := billingexpr.RunExpr(exprStr, billingexpr.TokenParams{P: 100}) + if err != nil { + t.Fatal(err) + } + if cost != 100 { + t.Errorf("cost = %f, want 100", cost) + } +} + +func TestTimeFunctions_AllFunctionsCompile(t *testing.T) { + exprStr := `tier("default", p) * (hour("Asia/Shanghai") >= 0 ? 1 : 1) * (minute("UTC") >= 0 ? 1 : 1) * (weekday("UTC") >= 0 ? 1 : 1) * (month("UTC") >= 1 ? 1 : 1) * (day("UTC") >= 1 ? 1 : 1)` + cost, _, err := billingexpr.RunExpr(exprStr, billingexpr.TokenParams{P: 500}) + if err != nil { + t.Fatal(err) + } + if cost != 500 { + t.Errorf("cost = %f, want 500", cost) + } +} + +func TestTimeFunctions_InvalidTimezone(t *testing.T) { + exprStr := `tier("default", p) * (hour("Invalid/Zone") >= 0 ? 1 : 2)` + cost, _, err := billingexpr.RunExpr(exprStr, billingexpr.TokenParams{P: 100}) + if err != nil { + t.Fatal(err) + } + // Invalid timezone falls back to UTC; hour is 0-23, so condition is always true + if cost != 100 { + t.Errorf("cost = %f, want 100 (fallback to UTC)", cost) + } +} + +func TestTimeFunctions_EmptyTimezone(t *testing.T) { + exprStr := `tier("default", p) * (hour("") >= 0 ? 1 : 2)` + cost, _, err := billingexpr.RunExpr(exprStr, billingexpr.TokenParams{P: 100}) + if err != nil { + t.Fatal(err) + } + if cost != 100 { + t.Errorf("cost = %f, want 100 (empty tz -> UTC)", cost) + } +} + +func TestTimeFunctions_NightDiscountPattern(t *testing.T) { + exprStr := `tier("default", p * 2 + c * 10) * (hour("UTC") >= 21 || hour("UTC") < 6 ? 0.5 : 1)` + cost, _, err := billingexpr.RunExpr(exprStr, billingexpr.TokenParams{P: 1000, C: 500}) + if err != nil { + t.Fatal(err) + } + // Base = 1000*2 + 500*10 = 7000; multiplier is either 0.5 or 1 depending on current UTC hour + if cost != 7000 && cost != 3500 { + t.Errorf("cost = %f, want 7000 or 3500", cost) + } +} + +func TestTimeFunctions_WeekdayRange(t *testing.T) { + exprStr := `tier("default", p) * (weekday("UTC") >= 0 && weekday("UTC") <= 6 ? 1 : 999)` + cost, _, err := billingexpr.RunExpr(exprStr, billingexpr.TokenParams{P: 100}) + if err != nil { + t.Fatal(err) + } + // weekday is always 0-6, so multiplier is always 1 + if cost != 100 { + t.Errorf("cost = %f, want 100", cost) + } +} + +func TestTimeFunctions_MonthDayPattern(t *testing.T) { + exprStr := `tier("default", p) * (month("Asia/Shanghai") == 1 && day("Asia/Shanghai") == 1 ? 0.5 : 1)` + cost, _, err := billingexpr.RunExpr(exprStr, billingexpr.TokenParams{P: 1000}) + if err != nil { + t.Fatal(err) + } + // Either 1000 (not Jan 1) or 500 (Jan 1) — both are valid + if cost != 1000 && cost != 500 { + t.Errorf("cost = %f, want 1000 or 500", cost) + } +} diff --git a/pkg/billingexpr/compile.go b/pkg/billingexpr/compile.go new file mode 100644 index 00000000..4a8b61e9 --- /dev/null +++ b/pkg/billingexpr/compile.go @@ -0,0 +1,91 @@ +package billingexpr + +import ( + "fmt" + "math" + "sync" + + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" +) + +const maxCacheSize = 256 + +var ( + cacheMu sync.RWMutex + cache = make(map[string]*vm.Program, 64) +) + +// compileEnvPrototype is the type-checking prototype used at compile time. +// It declares the shape of the environment that RunExpr will provide. +// The tier() function is a no-op placeholder here; the real one with +// side-channel tracing is injected at runtime. +var compileEnvPrototype = map[string]interface{}{ + "p": float64(0), + "c": float64(0), + "cr": float64(0), + "cc": float64(0), + "cc1h": float64(0), + "prompt_tokens": float64(0), + "completion_tokens": float64(0), + "cache_read_tokens": float64(0), + "cache_create_tokens": float64(0), + "cache_create_1h_tokens": float64(0), + "tier": func(string, float64) float64 { return 0 }, + "header": func(string) string { return "" }, + "param": func(string) interface{} { return nil }, + "has": func(interface{}, string) bool { return false }, + "hour": func(string) int { return 0 }, + "minute": func(string) int { return 0 }, + "weekday": func(string) int { return 0 }, + "month": func(string) int { return 0 }, + "day": func(string) int { return 0 }, + "max": math.Max, + "min": math.Min, + "abs": math.Abs, + "ceil": math.Ceil, + "floor": math.Floor, +} + +// CompileFromCache compiles an expression string, using a cached program when +// available. The cache is keyed by the SHA-256 hex digest of the expression. +func CompileFromCache(exprStr string) (*vm.Program, error) { + return compileFromCacheByHash(exprStr, ExprHashString(exprStr)) +} + +// CompileFromCacheByHash is like CompileFromCache but accepts a pre-computed +// hash, useful when the caller already has the BillingSnapshot.ExprHash. +func CompileFromCacheByHash(exprStr, hash string) (*vm.Program, error) { + return compileFromCacheByHash(exprStr, hash) +} + +func compileFromCacheByHash(exprStr, hash string) (*vm.Program, error) { + cacheMu.RLock() + if prog, ok := cache[hash]; ok { + cacheMu.RUnlock() + return prog, nil + } + cacheMu.RUnlock() + + prog, err := expr.Compile(exprStr, expr.Env(compileEnvPrototype), expr.AsFloat64()) + if err != nil { + return nil, fmt.Errorf("expr compile error: %w", err) + } + + cacheMu.Lock() + if len(cache) >= maxCacheSize { + cache = make(map[string]*vm.Program, 64) + } + cache[hash] = prog + cacheMu.Unlock() + + return prog, nil +} + +// InvalidateCache clears the compiled-expression cache. +// Called when billing rules are updated. +func InvalidateCache() { + cacheMu.Lock() + cache = make(map[string]*vm.Program, 64) + cacheMu.Unlock() +} diff --git a/pkg/billingexpr/round.go b/pkg/billingexpr/round.go new file mode 100644 index 00000000..35a5534a --- /dev/null +++ b/pkg/billingexpr/round.go @@ -0,0 +1,10 @@ +package billingexpr + +import "math" + +// QuotaRound converts a float64 quota value to int using half-away-from-zero +// rounding. Every tiered billing path (pre-consume, settlement, breakdown +// validation, log fields) MUST use this function to avoid +-1 discrepancies. +func QuotaRound(f float64) int { + return int(math.Round(f)) +} diff --git a/pkg/billingexpr/run.go b/pkg/billingexpr/run.go new file mode 100644 index 00000000..267c5af2 --- /dev/null +++ b/pkg/billingexpr/run.go @@ -0,0 +1,139 @@ +package billingexpr + +import ( + "fmt" + "math" + "strings" + "time" + + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" + "github.com/tidwall/gjson" +) + +// RunExpr compiles (with cache) and executes an expression string. +// The environment exposes: +// - p, c — prompt / completion tokens +// - cr, cc, cc1h — cache read / creation / creation-1h tokens +// - tier(name, value) — trace callback that records which tier matched +// - max, min, abs, ceil, floor — standard math helpers +// +// Returns the resulting float64 quota (before group ratio) and a TraceResult +// with side-channel info captured by tier() during execution. +func RunExpr(exprStr string, params TokenParams) (float64, TraceResult, error) { + return RunExprWithRequest(exprStr, params, RequestInput{}) +} + +func RunExprWithRequest(exprStr string, params TokenParams, request RequestInput) (float64, TraceResult, error) { + prog, err := CompileFromCache(exprStr) + if err != nil { + return 0, TraceResult{}, err + } + return runProgram(prog, params, request) +} + +// RunExprByHash is like RunExpr but accepts a pre-computed hash for the cache +// lookup, avoiding a redundant SHA-256 computation when the caller already +// holds BillingSnapshot.ExprHash. +func RunExprByHash(exprStr, hash string, params TokenParams) (float64, TraceResult, error) { + return RunExprByHashWithRequest(exprStr, hash, params, RequestInput{}) +} + +func RunExprByHashWithRequest(exprStr, hash string, params TokenParams, request RequestInput) (float64, TraceResult, error) { + prog, err := CompileFromCacheByHash(exprStr, hash) + if err != nil { + return 0, TraceResult{}, err + } + return runProgram(prog, params, request) +} + +func runProgram(prog *vm.Program, params TokenParams, request RequestInput) (float64, TraceResult, error) { + trace := TraceResult{} + headers := normalizeHeaders(request.Headers) + + env := map[string]interface{}{ + "p": params.P, + "c": params.C, + "cr": params.CR, + "cc": params.CC, + "cc1h": params.CC1h, + "prompt_tokens": params.P, + "completion_tokens": params.C, + "cache_read_tokens": params.CR, + "cache_create_tokens": params.CC, + "cache_create_1h_tokens": params.CC1h, + "tier": func(name string, value float64) float64 { + trace.MatchedTier = name + trace.Cost = value + return value + }, + "header": func(key string) string { + return headers[strings.ToLower(strings.TrimSpace(key))] + }, + "param": func(path string) interface{} { + path = strings.TrimSpace(path) + if path == "" || len(request.Body) == 0 { + return nil + } + result := gjson.GetBytes(request.Body, path) + if !result.Exists() { + return nil + } + return result.Value() + }, + "has": func(source interface{}, substr string) bool { + if source == nil || substr == "" { + return false + } + return strings.Contains(fmt.Sprint(source), substr) + }, + "hour": func(tz string) int { return timeInZone(tz).Hour() }, + "minute": func(tz string) int { return timeInZone(tz).Minute() }, + "weekday": func(tz string) int { return int(timeInZone(tz).Weekday()) }, + "month": func(tz string) int { return int(timeInZone(tz).Month()) }, + "day": func(tz string) int { return timeInZone(tz).Day() }, + "max": math.Max, + "min": math.Min, + "abs": math.Abs, + "ceil": math.Ceil, + "floor": math.Floor, + } + + out, err := expr.Run(prog, env) + if err != nil { + return 0, trace, fmt.Errorf("expr run error: %w", err) + } + f, ok := out.(float64) + if !ok { + return 0, trace, fmt.Errorf("expr result is %T, want float64", out) + } + return f, trace, nil +} + +func timeInZone(tz string) time.Time { + tz = strings.TrimSpace(tz) + if tz == "" { + return time.Now().UTC() + } + loc, err := time.LoadLocation(tz) + if err != nil { + return time.Now().UTC() + } + return time.Now().In(loc) +} + +func normalizeHeaders(headers map[string]string) map[string]string { + if len(headers) == 0 { + return map[string]string{} + } + normalized := make(map[string]string, len(headers)) + for key, value := range headers { + k := strings.ToLower(strings.TrimSpace(key)) + v := strings.TrimSpace(value) + if k == "" || v == "" { + continue + } + normalized[k] = v + } + return normalized +} diff --git a/pkg/billingexpr/settle.go b/pkg/billingexpr/settle.go new file mode 100644 index 00000000..7e69b9e3 --- /dev/null +++ b/pkg/billingexpr/settle.go @@ -0,0 +1,24 @@ +package billingexpr + +// ComputeTieredQuota runs the Expr from a frozen BillingSnapshot against +// actual token counts and returns the settlement result. +func ComputeTieredQuota(snap *BillingSnapshot, params TokenParams) (TieredResult, error) { + return ComputeTieredQuotaWithRequest(snap, params, RequestInput{}) +} + +func ComputeTieredQuotaWithRequest(snap *BillingSnapshot, params TokenParams, request RequestInput) (TieredResult, error) { + cost, trace, err := RunExprByHashWithRequest(snap.ExprString, snap.ExprHash, params, request) + if err != nil { + return TieredResult{}, err + } + + afterGroup := QuotaRound(cost * snap.GroupRatio) + crossed := trace.MatchedTier != snap.EstimatedTier + + return TieredResult{ + ActualQuotaBeforeGroup: cost, + ActualQuotaAfterGroup: afterGroup, + MatchedTier: trace.MatchedTier, + CrossedTier: crossed, + }, nil +} diff --git a/pkg/billingexpr/types.go b/pkg/billingexpr/types.go new file mode 100644 index 00000000..193f82b4 --- /dev/null +++ b/pkg/billingexpr/types.go @@ -0,0 +1,59 @@ +package billingexpr + +import ( + "crypto/sha256" + "fmt" +) + +type RequestInput struct { + Headers map[string]string + Body []byte +} + +// TokenParams holds all token dimensions passed into an Expr evaluation. +// Fields beyond P and C are optional — when absent they default to 0, +// which means cache-unaware expressions keep working unchanged. +type TokenParams struct { + P float64 // prompt tokens + C float64 // completion tokens + CR float64 // cache read (hit) tokens + CC float64 // cache creation tokens (5-min TTL for Claude, generic for others) + CC1h float64 // cache creation tokens — 1-hour TTL (Claude only) +} + +// TraceResult holds side-channel info captured by the tier() function +// during Expr execution. This replaces the old Breakdown mechanism — +// the Expr itself is the single source of truth for billing logic. +type TraceResult struct { + MatchedTier string `json:"matched_tier"` + Cost float64 `json:"cost"` +} + +// BillingSnapshot captures the billing rule state frozen at pre-consume time. +// It is fully serializable and contains no compiled program pointers. +type BillingSnapshot struct { + BillingMode string `json:"billing_mode"` + ModelName string `json:"model_name"` + ExprString string `json:"expr_string"` + ExprHash string `json:"expr_hash"` + GroupRatio float64 `json:"group_ratio"` + EstimatedPromptTokens int `json:"estimated_prompt_tokens"` + EstimatedCompletionTokens int `json:"estimated_completion_tokens"` + EstimatedQuotaBeforeGroup float64 `json:"estimated_quota_before_group"` + EstimatedQuotaAfterGroup int `json:"estimated_quota_after_group"` + EstimatedTier string `json:"estimated_tier"` +} + +// TieredResult holds everything needed after running tiered settlement. +type TieredResult struct { + ActualQuotaBeforeGroup float64 `json:"actual_quota_before_group"` + ActualQuotaAfterGroup int `json:"actual_quota_after_group"` + MatchedTier string `json:"matched_tier"` + CrossedTier bool `json:"crossed_tier"` +} + +// ExprHashString returns the SHA-256 hex digest of an expression string. +func ExprHashString(expr string) string { + h := sha256.Sum256([]byte(expr)) + return fmt.Sprintf("%x", h) +} diff --git a/relay/audio_handler.go b/relay/audio_handler.go index 5c34b792..e16ec29b 100644 --- a/relay/audio_handler.go +++ b/relay/audio_handler.go @@ -46,7 +46,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type resp, err := adaptor.DoRequest(c, info, ioReader) if err != nil { - return types.NewError(err, types.ErrorCodeDoRequestFailed) + return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } statusCodeMappingStr := c.GetString("status_code_mapping") diff --git a/relay/chat_completions_via_responses.go b/relay/chat_completions_via_responses.go index 8f69b937..7a2eb9aa 100644 --- a/relay/chat_completions_via_responses.go +++ b/relay/chat_completions_via_responses.go @@ -2,6 +2,7 @@ package relay import ( "bytes" + "io" "net/http" "strings" @@ -124,8 +125,10 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } + var requestBody io.Reader = bytes.NewBuffer(jsonData) + var httpResp *http.Response - resp, err := adaptor.DoRequest(c, info, bytes.NewBuffer(jsonData)) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } diff --git a/relay/common/billing.go b/relay/common/billing.go index 78f5cb19..3971426b 100644 --- a/relay/common/billing.go +++ b/relay/common/billing.go @@ -18,4 +18,7 @@ type BillingSettler interface { // GetPreConsumedQuota 返回实际预扣的额度值(信任用户可能为 0)。 GetPreConsumedQuota() int + + // Reserve 将预扣额度补到目标值;若目标值不高于当前预扣额度则不做任何事。 + Reserve(targetQuota int) error } diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 8b0789c0..a130c6f9 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -10,6 +10,7 @@ import ( "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/pkg/billingexpr" relayconstant "github.com/QuantumNous/new-api/relay/constant" "github.com/QuantumNous/new-api/setting/model_setting" "github.com/QuantumNous/new-api/types" @@ -152,6 +153,11 @@ type RelayInfo struct { PriceData types.PriceData + // TieredBillingSnapshot is a frozen snapshot of tiered billing rules + // captured at pre-consume time. Non-nil only when billing mode is "tiered_expr". + TieredBillingSnapshot *billingexpr.BillingSnapshot + BillingRequestInput *billingexpr.RequestInput + Request dto.Request // RequestConversionChain records request format conversions in order, e.g. diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go index f60a485b..392677af 100644 --- a/relay/compatible_handler.go +++ b/relay/compatible_handler.go @@ -13,6 +13,7 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/logger" "github.com/QuantumNous/new-api/model" + "github.com/QuantumNous/new-api/pkg/billingexpr" relaycommon "github.com/QuantumNous/new-api/relay/common" relayconstant "github.com/QuantumNous/new-api/relay/constant" "github.com/QuantumNous/new-api/relay/helper" @@ -236,6 +237,18 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage service.ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat()) } + // Tiered billing early return + if ok, tieredQuota, tieredResult := service.TryTieredSettle(relayInfo, billingexpr.TokenParams{ + P: float64(usage.PromptTokens), + C: float64(usage.CompletionTokens), + CR: float64(usage.PromptTokensDetails.CachedTokens), + CC: float64(usage.PromptTokensDetails.CachedCreationTokens - usage.ClaudeCacheCreation1hTokens), + CC1h: float64(usage.ClaudeCacheCreation1hTokens), + }); ok { + postConsumeQuotaTiered(ctx, relayInfo, usage, tieredQuota, tieredResult, extraContent...) + return + } + adminRejectReason := common.GetContextKeyString(ctx, constant.ContextKeyAdminRejectReason) useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() @@ -506,3 +519,51 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage Other: other, }) } + +func postConsumeQuotaTiered(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, quota int, tieredResult *service.TieredResultWrapper, extraContent ...string) { + _ = tieredResult // will be used for log enrichment + + useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() + modelName := relayInfo.OriginModelName + tokenName := ctx.GetString("token_name") + groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio + + totalTokens := usage.PromptTokens + usage.CompletionTokens + + if totalTokens == 0 { + quota = 0 + extraContent = append(extraContent, "上游没有返回计费信息,无法扣费(可能是上游超时)") + logger.LogError(ctx, fmt.Sprintf("tiered billing: total tokens is 0, userId %d, channelId %d, tokenId %d, model %s, pre-consumed %d", + relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota)) + } else { + if groupRatio != 0 && quota == 0 { + quota = 1 + } + model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) + model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) + } + + if err := service.SettleBilling(ctx, relayInfo, quota); err != nil { + logger.LogError(ctx, "error settling tiered billing: "+err.Error()) + } + + logModel := modelName + logContent := strings.Join(extraContent, ", ") + + other := service.GenerateTieredOtherInfo(ctx, relayInfo, tieredResult) + + model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ + ChannelId: relayInfo.ChannelId, + PromptTokens: usage.PromptTokens, + CompletionTokens: usage.CompletionTokens, + ModelName: logModel, + TokenName: tokenName, + Quota: quota, + Content: logContent, + TokenId: relayInfo.TokenId, + UseTimeSeconds: int(useTimeSeconds), + IsStream: relayInfo.IsStream, + Group: relayInfo.UsingGroup, + Other: other, + }) +} diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go index d8ca4223..7a73aa6e 100644 --- a/relay/embedding_handler.go +++ b/relay/embedding_handler.go @@ -3,6 +3,7 @@ package relay import ( "bytes" "fmt" + "io" "net/http" "github.com/QuantumNous/new-api/common" @@ -58,7 +59,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * } logger.LogDebug(c, fmt.Sprintf("converted embedding request body: %s", string(jsonData))) - requestBody := bytes.NewBuffer(jsonData) + var requestBody io.Reader = bytes.NewBuffer(jsonData) statusCodeMappingStr := c.GetString("status_code_mapping") resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { diff --git a/relay/helper/billing_expr_request.go b/relay/helper/billing_expr_request.go new file mode 100644 index 00000000..404a348f --- /dev/null +++ b/relay/helper/billing_expr_request.go @@ -0,0 +1,54 @@ +package helper + +import ( + "strings" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/pkg/billingexpr" + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/gin-gonic/gin" +) + +func ResolveIncomingBillingExprRequestInput(c *gin.Context, info *relaycommon.RelayInfo) (billingexpr.RequestInput, error) { + input := billingexpr.RequestInput{} + if info != nil { + input.Headers = cloneStringMap(info.RequestHeaders) + } + + bodyBytes, err := readIncomingBillingExprBody(c) + if err != nil { + return billingexpr.RequestInput{}, err + } + input.Body = bodyBytes + return input, nil +} + +func readIncomingBillingExprBody(c *gin.Context) ([]byte, error) { + if c == nil || c.Request == nil || !isJSONContentType(c.Request.Header.Get("Content-Type")) { + return nil, nil + } + storage, err := common.GetBodyStorage(c) + if err != nil { + return nil, err + } + return storage.Bytes() +} + +func isJSONContentType(contentType string) bool { + contentType = strings.ToLower(strings.TrimSpace(contentType)) + return strings.HasPrefix(contentType, "application/json") +} + +func cloneStringMap(src map[string]string) map[string]string { + if len(src) == 0 { + return map[string]string{} + } + dst := make(map[string]string, len(src)) + for key, value := range src { + if strings.TrimSpace(key) == "" { + continue + } + dst[key] = value + } + return dst +} diff --git a/relay/helper/billing_expr_request_test.go b/relay/helper/billing_expr_request_test.go new file mode 100644 index 00000000..c07aaa29 --- /dev/null +++ b/relay/helper/billing_expr_request_test.go @@ -0,0 +1,35 @@ +package helper + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/QuantumNous/new-api/common" + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestResolveIncomingBillingExprRequestInput(t *testing.T) { + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + ctx.Request.Header.Set("Content-Type", "application/json") + + body := []byte(`{"service_tier":"fast"}`) + ctx.Request.Body = io.NopCloser(bytes.NewReader(body)) + ctx.Set(common.KeyRequestBody, body) + + info := &relaycommon.RelayInfo{ + RequestHeaders: map[string]string{"Content-Type": "application/json"}, + } + + input, err := ResolveIncomingBillingExprRequestInput(ctx, info) + require.NoError(t, err) + require.Equal(t, body, input.Body) + require.Equal(t, "application/json", input.Headers["Content-Type"]) +} diff --git a/relay/helper/price.go b/relay/helper/price.go index f109040d..798f4e8d 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -5,7 +5,9 @@ import ( "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/logger" + "github.com/QuantumNous/new-api/pkg/billingexpr" relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/setting/billing_setting" "github.com/QuantumNous/new-api/setting/operation_setting" "github.com/QuantumNous/new-api/setting/ratio_setting" "github.com/QuantumNous/new-api/types" @@ -50,6 +52,11 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens groupRatioInfo := HandleGroupRatio(c, info) + // Check if this model uses tiered_expr billing + if billing_setting.GetBillingMode(info.OriginModelName) == billing_setting.BillingModeTieredExpr { + return modelPriceHelperTiered(c, info, promptTokens, meta, groupRatioInfo) + } + var preConsumedQuota int var modelRatio float64 var completionRatio float64 @@ -195,5 +202,73 @@ func ContainPriceOrRatio(modelName string) bool { if ok { return true } + if billing_setting.GetBillingMode(modelName) == billing_setting.BillingModeTieredExpr { + _, ok = billing_setting.GetBillingExpr(modelName) + return ok + } return false } + +func modelPriceHelperTiered(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, meta *types.TokenCountMeta, groupRatioInfo types.GroupRatioInfo) (types.PriceData, error) { + exprStr, ok := billing_setting.GetBillingExpr(info.OriginModelName) + if !ok { + return types.PriceData{}, fmt.Errorf("model %s is configured as tiered_expr but has no billing expression", info.OriginModelName) + } + + estimatedCompletionTokens := 0 + if meta.MaxTokens != 0 { + estimatedCompletionTokens = meta.MaxTokens + } + + requestInput, err := ResolveIncomingBillingExprRequestInput(c, info) + if err != nil { + return types.PriceData{}, err + } + + rawQuota, trace, err := billingexpr.RunExprWithRequest(exprStr, billingexpr.TokenParams{ + P: float64(promptTokens), + C: float64(estimatedCompletionTokens), + }, requestInput) + if err != nil { + return types.PriceData{}, fmt.Errorf("model %s tiered expr run failed: %w", info.OriginModelName, err) + } + + preConsumedQuota := billingexpr.QuotaRound(rawQuota * groupRatioInfo.GroupRatio) + + freeModel := false + if !operation_setting.GetQuotaSetting().EnableFreeModelPreConsume { + if groupRatioInfo.GroupRatio == 0 || rawQuota == 0 { + preConsumedQuota = 0 + freeModel = true + } + } + + exprHash := billingexpr.ExprHashString(exprStr) + snapshot := &billingexpr.BillingSnapshot{ + BillingMode: billing_setting.BillingModeTieredExpr, + ModelName: info.OriginModelName, + ExprString: exprStr, + ExprHash: exprHash, + GroupRatio: groupRatioInfo.GroupRatio, + EstimatedPromptTokens: promptTokens, + EstimatedCompletionTokens: estimatedCompletionTokens, + EstimatedQuotaBeforeGroup: rawQuota, + EstimatedQuotaAfterGroup: preConsumedQuota, + EstimatedTier: trace.MatchedTier, + } + info.TieredBillingSnapshot = snapshot + info.BillingRequestInput = &requestInput + + priceData := types.PriceData{ + FreeModel: freeModel, + GroupRatioInfo: groupRatioInfo, + QuotaToPreConsume: preConsumedQuota, + } + + if common.DebugEnabled { + println(fmt.Sprintf("model_price_helper_tiered result: model=%s preConsume=%d rawQuota=%.2f groupRatio=%.2f tier=%s", info.OriginModelName, preConsumedQuota, rawQuota, groupRatioInfo.GroupRatio, trace.MatchedTier)) + } + + info.PriceData = priceData + return priceData, nil +} diff --git a/service/billing_session.go b/service/billing_session.go index f24b68e5..3e3ef1f4 100644 --- a/service/billing_session.go +++ b/service/billing_session.go @@ -27,6 +27,8 @@ type BillingSession struct { funding FundingSource preConsumedQuota int // 实际预扣额度(信任用户可能为 0) tokenConsumed int // 令牌额度实际扣减量 + extraReserved int // 发送前补充预扣的额度(订阅退款时需要单独回滚) + trusted bool // 是否命中信任额度旁路 fundingSettled bool // funding.Settle 已成功,资金来源已提交 settled bool // Settle 全部完成(资金 + 令牌) refunded bool // Refund 已调用 @@ -97,6 +99,8 @@ func (s *BillingSession) Refund(c *gin.Context) { tokenKey := s.relayInfo.TokenKey isPlayground := s.relayInfo.IsPlayground tokenConsumed := s.tokenConsumed + extraReserved := s.extraReserved + subscriptionId := s.relayInfo.SubscriptionId funding := s.funding gopool.Go(func() { @@ -104,6 +108,11 @@ func (s *BillingSession) Refund(c *gin.Context) { if err := funding.Refund(); err != nil { common.SysLog("error refunding billing source: " + err.Error()) } + if extraReserved > 0 && funding.Source() == BillingSourceSubscription && subscriptionId > 0 { + if err := model.PostConsumeUserSubscriptionDelta(subscriptionId, -int64(extraReserved)); err != nil { + common.SysLog("error refunding subscription extra reserved quota: " + err.Error()) + } + } // 2) 退还令牌额度 if tokenConsumed > 0 && !isPlayground { if err := model.IncreaseTokenQuota(tokenId, tokenKey, tokenConsumed); err != nil { @@ -140,6 +149,34 @@ func (s *BillingSession) GetPreConsumedQuota() int { return s.preConsumedQuota } +func (s *BillingSession) Reserve(targetQuota int) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.settled || s.refunded || s.trusted || targetQuota <= s.preConsumedQuota { + return nil + } + + delta := targetQuota - s.preConsumedQuota + if delta <= 0 { + return nil + } + + if err := s.reserveFunding(delta); err != nil { + return err + } + if err := s.reserveToken(delta); err != nil { + s.rollbackFundingReserve(delta) + return err + } + + s.preConsumedQuota += delta + s.tokenConsumed += delta + s.extraReserved += delta + s.syncRelayInfo() + return nil +} + // --------------------------------------------------------------------------- // PreConsume — 统一预扣费入口(含信任额度旁路) // --------------------------------------------------------------------------- @@ -151,6 +188,7 @@ func (s *BillingSession) preConsume(c *gin.Context, quota int) *types.NewAPIErro // ---- 信任额度旁路 ---- if s.shouldTrust(c) { + s.trusted = true effectiveQuota = 0 logger.LogInfo(c, fmt.Sprintf("用户 %d 额度充足, 信任且不需要预扣费 (funding=%s)", s.relayInfo.UserId, s.funding.Source())) } else if effectiveQuota > 0 { @@ -191,6 +229,55 @@ func (s *BillingSession) preConsume(c *gin.Context, quota int) *types.NewAPIErro return nil } +func (s *BillingSession) reserveFunding(delta int) error { + switch funding := s.funding.(type) { + case *WalletFunding: + if err := model.DecreaseUserQuota(funding.userId, delta); err != nil { + return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry()) + } + funding.consumed += delta + return nil + case *SubscriptionFunding: + if err := model.PostConsumeUserSubscriptionDelta(funding.subscriptionId, int64(delta)); err != nil { + return types.NewErrorWithStatusCode( + fmt.Errorf("订阅额度不足或未配置订阅: %s", err.Error()), + types.ErrorCodeInsufficientUserQuota, + http.StatusForbidden, + types.ErrOptionWithSkipRetry(), + types.ErrOptionWithNoRecordErrorLog(), + ) + } + return nil + default: + return types.NewError(fmt.Errorf("unsupported funding source: %s", s.funding.Source()), types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry()) + } +} + +func (s *BillingSession) rollbackFundingReserve(delta int) { + switch funding := s.funding.(type) { + case *WalletFunding: + if err := model.IncreaseUserQuota(funding.userId, delta, false); err != nil { + common.SysLog("error rolling back wallet funding reserve: " + err.Error()) + } else { + funding.consumed -= delta + } + case *SubscriptionFunding: + if err := model.PostConsumeUserSubscriptionDelta(funding.subscriptionId, -int64(delta)); err != nil { + common.SysLog("error rolling back subscription funding reserve: " + err.Error()) + } + } +} + +func (s *BillingSession) reserveToken(delta int) error { + if delta <= 0 || s.relayInfo.IsPlayground { + return nil + } + if err := PreConsumeTokenQuota(s.relayInfo, delta); err != nil { + return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + } + return nil +} + // shouldTrust 统一信任额度检查,适用于钱包和订阅。 func (s *BillingSession) shouldTrust(c *gin.Context) bool { // 异步任务(ForcePreConsume=true)必须预扣全额,不允许信任旁路 @@ -235,10 +322,10 @@ func (s *BillingSession) syncRelayInfo() { if sub, ok := s.funding.(*SubscriptionFunding); ok { info.SubscriptionId = sub.subscriptionId - info.SubscriptionPreConsumed = sub.preConsumed + info.SubscriptionPreConsumed = sub.preConsumed + int64(s.extraReserved) info.SubscriptionPostDelta = 0 info.SubscriptionAmountTotal = sub.AmountTotal - info.SubscriptionAmountUsedAfterPreConsume = sub.AmountUsedAfter + info.SubscriptionAmountUsedAfterPreConsume = sub.AmountUsedAfter + int64(s.extraReserved) info.SubscriptionPlanId = sub.PlanId info.SubscriptionPlanTitle = sub.PlanTitle } else { diff --git a/service/log_info_generate.go b/service/log_info_generate.go index 1c440911..0737a911 100644 --- a/service/log_info_generate.go +++ b/service/log_info_generate.go @@ -6,6 +6,7 @@ import ( "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/pkg/billingexpr" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/types" @@ -214,3 +215,42 @@ func GenerateMjOtherInfo(relayInfo *relaycommon.RelayInfo, priceData types.Price appendRequestPath(nil, relayInfo, other) return other } + +func GenerateTieredOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, result *billingexpr.TieredResult) map[string]interface{} { + other := make(map[string]interface{}) + other["billing_mode"] = "tiered_expr" + + snap := relayInfo.TieredBillingSnapshot + if snap != nil { + other["group_ratio"] = snap.GroupRatio + other["expr_hash"] = snap.ExprHash + other["estimated_prompt_tokens"] = snap.EstimatedPromptTokens + other["estimated_completion_tokens"] = snap.EstimatedCompletionTokens + other["estimated_quota_before_group"] = snap.EstimatedQuotaBeforeGroup + other["estimated_quota_after_group"] = snap.EstimatedQuotaAfterGroup + other["estimated_tier"] = snap.EstimatedTier + } + + if result != nil { + other["actual_quota_before_group"] = result.ActualQuotaBeforeGroup + other["actual_quota_after_group"] = result.ActualQuotaAfterGroup + other["matched_tier"] = result.MatchedTier + other["crossed_tier"] = result.CrossedTier + } + + other["frt"] = float64(relayInfo.FirstResponseTime.UnixMilli() - relayInfo.StartTime.UnixMilli()) + if relayInfo.IsModelMapped { + other["is_model_mapped"] = true + other["upstream_model_name"] = relayInfo.UpstreamModelName + } + + adminInfo := make(map[string]interface{}) + adminInfo["use_channel"] = ctx.GetStringSlice("use_channel") + AppendChannelAffinityAdminInfo(ctx, adminInfo) + other["admin_info"] = adminInfo + + appendRequestPath(ctx, relayInfo, other) + appendRequestConversionChain(relayInfo, other) + appendBillingInfo(relayInfo, other) + return other +} diff --git a/service/quota.go b/service/quota.go index 7ee70edd..eafec351 100644 --- a/service/quota.go +++ b/service/quota.go @@ -13,6 +13,7 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/logger" "github.com/QuantumNous/new-api/model" + "github.com/QuantumNous/new-api/pkg/billingexpr" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/setting/ratio_setting" "github.com/QuantumNous/new-api/setting/system_setting" @@ -157,6 +158,15 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string, usage *dto.RealtimeUsage, extraContent string) { + // Tiered billing early return + if ok, tieredQuota, tieredResult := TryTieredSettle(relayInfo, billingexpr.TokenParams{ + P: float64(usage.InputTokens), + C: float64(usage.OutputTokens), + }); ok { + postConsumeQuotaTieredService(ctx, relayInfo, modelName, usage.InputTokens, usage.OutputTokens, usage.TotalTokens, tieredQuota, tieredResult, extraContent) + return + } + useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() textInputTokens := usage.InputTokenDetails.TextTokens textOutTokens := usage.OutputTokenDetails.TextTokens @@ -240,6 +250,18 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat()) } + // Tiered billing early return + if ok, tieredQuota, tieredResult := TryTieredSettle(relayInfo, billingexpr.TokenParams{ + P: float64(usage.PromptTokens), + C: float64(usage.CompletionTokens), + CR: float64(usage.PromptTokensDetails.CachedTokens), + CC: float64(usage.PromptTokensDetails.CachedCreationTokens - usage.ClaudeCacheCreation1hTokens), + CC1h: float64(usage.ClaudeCacheCreation1hTokens), + }); ok { + postConsumeQuotaTieredService(ctx, relayInfo, relayInfo.OriginModelName, usage.PromptTokens, usage.CompletionTokens, usage.PromptTokens+usage.CompletionTokens, tieredQuota, tieredResult, "") + return + } + useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() promptTokens := usage.PromptTokens completionTokens := usage.CompletionTokens @@ -360,6 +382,16 @@ func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData types.PriceData) func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) { + // Tiered billing early return + if ok, tieredQuota, tieredResult := TryTieredSettle(relayInfo, billingexpr.TokenParams{ + P: float64(usage.PromptTokens), + C: float64(usage.CompletionTokens), + CR: float64(usage.PromptTokensDetails.CachedTokens), + }); ok { + postConsumeQuotaTieredService(ctx, relayInfo, relayInfo.OriginModelName, usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens, tieredQuota, tieredResult, extraContent) + return + } + useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() textInputTokens := usage.PromptTokensDetails.TextTokens textOutTokens := usage.CompletionTokenDetails.TextTokens @@ -607,3 +639,50 @@ func checkAndSendSubscriptionQuotaNotify(relayInfo *relaycommon.RelayInfo) { } }) } + +func postConsumeQuotaTieredService(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string, + promptTokens, completionTokens, totalTokens, quota int, tieredResult *TieredResultWrapper, extraContent string) { + + useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() + tokenName := ctx.GetString("token_name") + groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio + + var logContent string + if totalTokens == 0 { + quota = 0 + logContent = "上游没有返回计费信息(可能是上游超时)" + logger.LogError(ctx, fmt.Sprintf("tiered billing: total tokens is 0, userId %d, channelId %d, tokenId %d, model %s, pre-consumed %d", + relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota)) + } else { + if groupRatio != 0 && quota == 0 { + quota = 1 + } + model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) + model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) + } + + if err := SettleBilling(ctx, relayInfo, quota); err != nil { + logger.LogError(ctx, "error settling tiered billing: "+err.Error()) + } + + if extraContent != "" { + logContent += extraContent + } + + other := GenerateTieredOtherInfo(ctx, relayInfo, tieredResult) + + model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ + ChannelId: relayInfo.ChannelId, + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + ModelName: modelName, + TokenName: tokenName, + Quota: quota, + Content: logContent, + TokenId: relayInfo.TokenId, + UseTimeSeconds: int(useTimeSeconds), + IsStream: relayInfo.IsStream, + Group: relayInfo.UsingGroup, + Other: other, + }) +} diff --git a/service/tiered_settle.go b/service/tiered_settle.go new file mode 100644 index 00000000..83f5b6f6 --- /dev/null +++ b/service/tiered_settle.go @@ -0,0 +1,36 @@ +package service + +import ( + "github.com/QuantumNous/new-api/pkg/billingexpr" + relaycommon "github.com/QuantumNous/new-api/relay/common" +) + +// TieredResultWrapper wraps billingexpr.TieredResult for use at the service layer. +type TieredResultWrapper = billingexpr.TieredResult + +// TryTieredSettle checks if the request uses tiered_expr billing and, if so, +// computes the actual quota using the frozen BillingSnapshot. Returns: +// - ok=true, quota, result when tiered billing applies +// - ok=false, 0, nil when it doesn't (caller should fall through to existing logic) +func TryTieredSettle(relayInfo *relaycommon.RelayInfo, params billingexpr.TokenParams) (ok bool, quota int, result *billingexpr.TieredResult) { + snap := relayInfo.TieredBillingSnapshot + if snap == nil || snap.BillingMode != "tiered_expr" { + return false, 0, nil + } + + requestInput := billingexpr.RequestInput{} + if relayInfo.BillingRequestInput != nil { + requestInput = *relayInfo.BillingRequestInput + } + + tr, err := billingexpr.ComputeTieredQuotaWithRequest(snap, params, requestInput) + if err != nil { + quota = relayInfo.FinalPreConsumedQuota + if quota <= 0 { + quota = snap.EstimatedQuotaAfterGroup + } + return true, quota, nil + } + + return true, tr.ActualQuotaAfterGroup, &tr +} diff --git a/service/tiered_settle_test.go b/service/tiered_settle_test.go new file mode 100644 index 00000000..4ec702ad --- /dev/null +++ b/service/tiered_settle_test.go @@ -0,0 +1,401 @@ +package service + +import ( + "testing" + + "github.com/QuantumNous/new-api/pkg/billingexpr" + relaycommon "github.com/QuantumNous/new-api/relay/common" +) + +// Claude Sonnet-style tiered expression: standard vs long-context +const sonnetTieredExpr = `p <= 200000 ? tier("standard", p * 1.5 + c * 7.5) : tier("long_context", p * 3 + c * 11.25)` + +// Simple flat expression +const flatExpr = `tier("default", p * 2 + c * 10)` + +// Expression with cache tokens +const cacheExpr = `tier("default", p * 2 + c * 10 + cr * 0.2 + cc * 2.5 + cc1h * 4)` + +// Expression with request probes +const probeExpr = `param("service_tier") == "fast" ? tier("fast", p * 4 + c * 20) : tier("normal", p * 2 + c * 10)` + +func makeSnapshot(expr string, groupRatio float64, estPrompt, estCompletion int) *billingexpr.BillingSnapshot { + return &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: expr, + ExprHash: billingexpr.ExprHashString(expr), + GroupRatio: groupRatio, + EstimatedPromptTokens: estPrompt, + EstimatedCompletionTokens: estCompletion, + } +} + +func makeRelayInfo(expr string, groupRatio float64, estPrompt, estCompletion int) *relaycommon.RelayInfo { + snap := makeSnapshot(expr, groupRatio, estPrompt, estCompletion) + cost, trace, _ := billingexpr.RunExpr(expr, billingexpr.TokenParams{P: float64(estPrompt), C: float64(estCompletion)}) + snap.EstimatedQuotaBeforeGroup = cost + snap.EstimatedQuotaAfterGroup = billingexpr.QuotaRound(cost * groupRatio) + snap.EstimatedTier = trace.MatchedTier + return &relaycommon.RelayInfo{ + TieredBillingSnapshot: snap, + FinalPreConsumedQuota: snap.EstimatedQuotaAfterGroup, + } +} + +// --------------------------------------------------------------------------- +// Existing tests (preserved) +// --------------------------------------------------------------------------- + +func TestTryTieredSettleUsesFrozenRequestInput(t *testing.T) { + exprStr := `param("service_tier") == "fast" ? tier("fast", p * 2) : tier("normal", p)` + relayInfo := &relaycommon.RelayInfo{ + TieredBillingSnapshot: &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: exprStr, + ExprHash: billingexpr.ExprHashString(exprStr), + GroupRatio: 1.0, + EstimatedPromptTokens: 100, + EstimatedCompletionTokens: 0, + EstimatedQuotaAfterGroup: 100, + }, + BillingRequestInput: &billingexpr.RequestInput{ + Body: []byte(`{"service_tier":"fast"}`), + }, + } + + ok, quota, result := TryTieredSettle(relayInfo, billingexpr.TokenParams{P: 100}) + if !ok { + t.Fatal("expected tiered settle to apply") + } + if quota != 200 { + t.Fatalf("quota = %d, want 200", quota) + } + if result == nil || result.MatchedTier != "fast" { + t.Fatalf("matched tier = %v, want fast", result) + } +} + +func TestTryTieredSettleFallsBackToFrozenPreConsumeOnExprError(t *testing.T) { + relayInfo := &relaycommon.RelayInfo{ + FinalPreConsumedQuota: 321, + TieredBillingSnapshot: &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: `invalid +-+ expr`, + ExprHash: billingexpr.ExprHashString(`invalid +-+ expr`), + GroupRatio: 1.0, + EstimatedQuotaAfterGroup: 123, + }, + } + + ok, quota, result := TryTieredSettle(relayInfo, billingexpr.TokenParams{P: 100}) + if !ok { + t.Fatal("expected tiered settle to apply") + } + if quota != 321 { + t.Fatalf("quota = %d, want 321", quota) + } + if result != nil { + t.Fatalf("result = %#v, want nil", result) + } +} + +// --------------------------------------------------------------------------- +// Pre-consume vs Post-consume consistency +// --------------------------------------------------------------------------- + +func TestTryTieredSettle_PreConsumeMatchesPostConsume(t *testing.T) { + info := makeRelayInfo(flatExpr, 1.0, 1000, 500) + params := billingexpr.TokenParams{P: 1000, C: 500} + + ok, quota, _ := TryTieredSettle(info, params) + if !ok { + t.Fatal("expected tiered settle") + } + // p*2 + c*10 = 2000 + 5000 = 7000 + if quota != 7000 { + t.Fatalf("quota = %d, want 7000", quota) + } + if quota != info.FinalPreConsumedQuota { + t.Fatalf("pre-consume %d != post-consume %d", info.FinalPreConsumedQuota, quota) + } +} + +func TestTryTieredSettle_PostConsumeOverPreConsume(t *testing.T) { + info := makeRelayInfo(flatExpr, 1.0, 1000, 500) + preConsumed := info.FinalPreConsumedQuota // 7000 + + // Actual usage is higher than estimated + params := billingexpr.TokenParams{P: 2000, C: 1000} + ok, quota, _ := TryTieredSettle(info, params) + if !ok { + t.Fatal("expected tiered settle") + } + // p*2 + c*10 = 4000 + 10000 = 14000 + if quota != 14000 { + t.Fatalf("quota = %d, want 14000", quota) + } + if quota <= preConsumed { + t.Fatalf("expected supplement: actual %d should > pre-consumed %d", quota, preConsumed) + } +} + +func TestTryTieredSettle_PostConsumeUnderPreConsume(t *testing.T) { + info := makeRelayInfo(flatExpr, 1.0, 1000, 500) + preConsumed := info.FinalPreConsumedQuota // 7000 + + // Actual usage is lower than estimated + params := billingexpr.TokenParams{P: 100, C: 50} + ok, quota, _ := TryTieredSettle(info, params) + if !ok { + t.Fatal("expected tiered settle") + } + // p*2 + c*10 = 200 + 500 = 700 + if quota != 700 { + t.Fatalf("quota = %d, want 700", quota) + } + if quota >= preConsumed { + t.Fatalf("expected refund: actual %d should < pre-consumed %d", quota, preConsumed) + } +} + +// --------------------------------------------------------------------------- +// Tiered boundary conditions +// --------------------------------------------------------------------------- + +func TestTryTieredSettle_ExactBoundary(t *testing.T) { + info := makeRelayInfo(sonnetTieredExpr, 1.0, 200000, 1000) + + // p == 200000 => standard tier (p <= 200000) + ok, quota, result := TryTieredSettle(info, billingexpr.TokenParams{P: 200000, C: 1000}) + if !ok { + t.Fatal("expected tiered settle") + } + // standard: p*1.5 + c*7.5 = 300000 + 7500 = 307500 + if quota != 307500 { + t.Fatalf("quota = %d, want 307500", quota) + } + if result.MatchedTier != "standard" { + t.Fatalf("tier = %s, want standard", result.MatchedTier) + } +} + +func TestTryTieredSettle_BoundaryPlusOne(t *testing.T) { + info := makeRelayInfo(sonnetTieredExpr, 1.0, 200000, 1000) + + // p == 200001 => crosses to long_context tier + ok, quota, result := TryTieredSettle(info, billingexpr.TokenParams{P: 200001, C: 1000}) + if !ok { + t.Fatal("expected tiered settle") + } + // long_context: p*3 + c*11.25 = 600003 + 11250 = 611253 + if quota != 611253 { + t.Fatalf("quota = %d, want 611253", quota) + } + if result.MatchedTier != "long_context" { + t.Fatalf("tier = %s, want long_context", result.MatchedTier) + } + if !result.CrossedTier { + t.Fatal("expected CrossedTier = true") + } +} + +func TestTryTieredSettle_ZeroTokens(t *testing.T) { + info := makeRelayInfo(flatExpr, 1.0, 0, 0) + + ok, quota, result := TryTieredSettle(info, billingexpr.TokenParams{P: 0, C: 0}) + if !ok { + t.Fatal("expected tiered settle") + } + if quota != 0 { + t.Fatalf("quota = %d, want 0", quota) + } + if result == nil { + t.Fatal("result should not be nil") + } +} + +func TestTryTieredSettle_HugeTokens(t *testing.T) { + info := makeRelayInfo(flatExpr, 1.0, 10000000, 5000000) + + ok, quota, _ := TryTieredSettle(info, billingexpr.TokenParams{P: 10000000, C: 5000000}) + if !ok { + t.Fatal("expected tiered settle") + } + // p*2 + c*10 = 20000000 + 50000000 = 70000000 + if quota != 70000000 { + t.Fatalf("quota = %d, want 70000000", quota) + } +} + +func TestTryTieredSettle_CacheTokensAffectSettlement(t *testing.T) { + info := makeRelayInfo(cacheExpr, 1.0, 1000, 500) + + // Without cache tokens + ok1, quota1, _ := TryTieredSettle(info, billingexpr.TokenParams{P: 1000, C: 500}) + if !ok1 { + t.Fatal("expected tiered settle") + } + // p*2 + c*10 + cr*0.2 + cc*2.5 + cc1h*4 = 2000 + 5000 + 0 + 0 + 0 = 7000 + + // With cache tokens + ok2, quota2, _ := TryTieredSettle(info, billingexpr.TokenParams{P: 1000, C: 500, CR: 10000, CC: 5000, CC1h: 2000}) + if !ok2 { + t.Fatal("expected tiered settle") + } + // 2000 + 5000 + 10000*0.2 + 5000*2.5 + 2000*4 = 2000 + 5000 + 2000 + 12500 + 8000 = 29500 + + if quota2 <= quota1 { + t.Fatalf("cache tokens should increase quota: without=%d, with=%d", quota1, quota2) + } + if quota1 != 7000 { + t.Fatalf("no-cache quota = %d, want 7000", quota1) + } + if quota2 != 29500 { + t.Fatalf("cache quota = %d, want 29500", quota2) + } +} + +// --------------------------------------------------------------------------- +// Request probe tests +// --------------------------------------------------------------------------- + +func TestTryTieredSettle_RequestProbeInfluencesBilling(t *testing.T) { + info := makeRelayInfo(probeExpr, 1.0, 1000, 500) + info.BillingRequestInput = &billingexpr.RequestInput{ + Body: []byte(`{"service_tier":"fast"}`), + } + + ok, quota, result := TryTieredSettle(info, billingexpr.TokenParams{P: 1000, C: 500}) + if !ok { + t.Fatal("expected tiered settle") + } + // fast: p*4 + c*20 = 4000 + 10000 = 14000 + if quota != 14000 { + t.Fatalf("quota = %d, want 14000", quota) + } + if result.MatchedTier != "fast" { + t.Fatalf("tier = %s, want fast", result.MatchedTier) + } +} + +func TestTryTieredSettle_NoRequestInput_FallsBackToDefault(t *testing.T) { + info := makeRelayInfo(probeExpr, 1.0, 1000, 500) + // No BillingRequestInput set — param("service_tier") returns nil, not "fast" + + ok, quota, result := TryTieredSettle(info, billingexpr.TokenParams{P: 1000, C: 500}) + if !ok { + t.Fatal("expected tiered settle") + } + // normal: p*2 + c*10 = 2000 + 5000 = 7000 + if quota != 7000 { + t.Fatalf("quota = %d, want 7000", quota) + } + if result.MatchedTier != "normal" { + t.Fatalf("tier = %s, want normal", result.MatchedTier) + } +} + +// --------------------------------------------------------------------------- +// Group ratio tests +// --------------------------------------------------------------------------- + +func TestTryTieredSettle_GroupRatioScaling(t *testing.T) { + info := makeRelayInfo(flatExpr, 1.5, 1000, 500) + + ok, quota, _ := TryTieredSettle(info, billingexpr.TokenParams{P: 1000, C: 500}) + if !ok { + t.Fatal("expected tiered settle") + } + // cost = 7000, after group = round(7000 * 1.5) = 10500 + if quota != 10500 { + t.Fatalf("quota = %d, want 10500", quota) + } +} + +func TestTryTieredSettle_GroupRatioZero(t *testing.T) { + info := makeRelayInfo(flatExpr, 0, 1000, 500) + + ok, quota, _ := TryTieredSettle(info, billingexpr.TokenParams{P: 1000, C: 500}) + if !ok { + t.Fatal("expected tiered settle") + } + if quota != 0 { + t.Fatalf("quota = %d, want 0 (group ratio = 0)", quota) + } +} + +// --------------------------------------------------------------------------- +// Ratio mode (negative tests) — TryTieredSettle must return false +// --------------------------------------------------------------------------- + +func TestTryTieredSettle_RatioMode_NilSnapshot(t *testing.T) { + info := &relaycommon.RelayInfo{ + TieredBillingSnapshot: nil, + } + + ok, _, _ := TryTieredSettle(info, billingexpr.TokenParams{P: 1000, C: 500}) + if ok { + t.Fatal("expected TryTieredSettle to return false when snapshot is nil") + } +} + +func TestTryTieredSettle_RatioMode_WrongBillingMode(t *testing.T) { + info := &relaycommon.RelayInfo{ + TieredBillingSnapshot: &billingexpr.BillingSnapshot{ + BillingMode: "ratio", + ExprString: flatExpr, + ExprHash: billingexpr.ExprHashString(flatExpr), + GroupRatio: 1.0, + }, + } + + ok, _, _ := TryTieredSettle(info, billingexpr.TokenParams{P: 1000, C: 500}) + if ok { + t.Fatal("expected TryTieredSettle to return false for ratio billing mode") + } +} + +func TestTryTieredSettle_RatioMode_EmptyBillingMode(t *testing.T) { + info := &relaycommon.RelayInfo{ + TieredBillingSnapshot: &billingexpr.BillingSnapshot{ + BillingMode: "", + ExprString: flatExpr, + ExprHash: billingexpr.ExprHashString(flatExpr), + GroupRatio: 1.0, + }, + } + + ok, _, _ := TryTieredSettle(info, billingexpr.TokenParams{P: 1000, C: 500}) + if ok { + t.Fatal("expected TryTieredSettle to return false for empty billing mode") + } +} + +// --------------------------------------------------------------------------- +// Fallback tests +// --------------------------------------------------------------------------- + +func TestTryTieredSettle_ErrorFallbackToEstimatedQuotaAfterGroup(t *testing.T) { + info := &relaycommon.RelayInfo{ + FinalPreConsumedQuota: 0, + TieredBillingSnapshot: &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: `invalid expr!!!`, + ExprHash: billingexpr.ExprHashString(`invalid expr!!!`), + GroupRatio: 1.0, + EstimatedQuotaAfterGroup: 999, + }, + } + + ok, quota, result := TryTieredSettle(info, billingexpr.TokenParams{P: 100}) + if !ok { + t.Fatal("expected tiered settle to apply") + } + // FinalPreConsumedQuota is 0, should fall back to EstimatedQuotaAfterGroup + if quota != 999 { + t.Fatalf("quota = %d, want 999", quota) + } + if result != nil { + t.Fatal("result should be nil on error fallback") + } +} diff --git a/setting/billing_setting/tiered_billing.go b/setting/billing_setting/tiered_billing.go new file mode 100644 index 00000000..85d5e628 --- /dev/null +++ b/setting/billing_setting/tiered_billing.go @@ -0,0 +1,135 @@ +package billing_setting + +import ( + "fmt" + "sync" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/pkg/billingexpr" +) + +var ( + mu sync.RWMutex + + // model -> "ratio" | "tiered_expr" + billingModeMap = make(map[string]string) + + // model -> expr string (authored by frontend, stored directly) + billingExprMap = make(map[string]string) +) + +const ( + BillingModeRatio = "ratio" + BillingModeTieredExpr = "tiered_expr" +) + +// --------------------------------------------------------------------------- +// Read accessors (hot path, must be fast) +// --------------------------------------------------------------------------- + +func GetBillingMode(model string) string { + mu.RLock() + defer mu.RUnlock() + if mode, ok := billingModeMap[model]; ok { + return mode + } + return BillingModeRatio +} + +func GetBillingExpr(model string) (string, bool) { + mu.RLock() + defer mu.RUnlock() + expr, ok := billingExprMap[model] + return expr, ok +} + +func UpdateBillingModeByJSONString(jsonStr string) error { + var m map[string]string + if err := common.Unmarshal([]byte(jsonStr), &m); err != nil { + return fmt.Errorf("parse ModelBillingMode: %w", err) + } + for k, v := range m { + if v != BillingModeRatio && v != BillingModeTieredExpr { + return fmt.Errorf("invalid billing mode %q for model %q", v, k) + } + } + mu.Lock() + billingModeMap = m + mu.Unlock() + return nil +} + +func UpdateBillingExprByJSONString(jsonStr string) error { + var m map[string]string + if err := common.Unmarshal([]byte(jsonStr), &m); err != nil { + return fmt.Errorf("parse ModelBillingExpr: %w", err) + } + for model, exprStr := range m { + if _, err := billingexpr.CompileFromCache(exprStr); err != nil { + return fmt.Errorf("model %q: %w", model, err) + } + if err := smokeTestExpr(exprStr); err != nil { + return fmt.Errorf("model %q smoke test: %w", model, err) + } + } + mu.Lock() + billingExprMap = m + mu.Unlock() + billingexpr.InvalidateCache() + return nil +} + +// --------------------------------------------------------------------------- +// JSON serializers (for OptionMap / API response) +// --------------------------------------------------------------------------- + +func BillingMode2JSONString() string { + mu.RLock() + defer mu.RUnlock() + b, err := common.Marshal(billingModeMap) + if err != nil { + return "{}" + } + return string(b) +} + +func BillingExpr2JSONString() string { + mu.RLock() + defer mu.RUnlock() + b, err := common.Marshal(billingExprMap) + if err != nil { + return "{}" + } + return string(b) +} + +func smokeTestExpr(exprStr string) error { + vectors := []billingexpr.TokenParams{ + {P: 0, C: 0}, + {P: 1000, C: 1000}, + {P: 100000, C: 100000}, + {P: 1000000, C: 1000000}, + } + requests := []billingexpr.RequestInput{ + {}, + { + Headers: map[string]string{ + "anthropic-beta": "fast-mode-2026-02-01", + }, + Body: []byte(`{"service_tier":"fast","stream_options":{"include_usage":true},"messages":[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21]}`), + }, + } + + for _, v := range vectors { + for _, request := range requests { + result, _, err := billingexpr.RunExprWithRequest(exprStr, v, request) + if err != nil { + return fmt.Errorf("vector {p=%g, c=%g}: run failed: %w", v.P, v.C, err) + } + if result < 0 { + return fmt.Errorf("vector {p=%g, c=%g}: result %f < 0", v.P, v.C, result) + } + } + } + return nil +} diff --git a/setting/model_setting/claude_test.go b/setting/model_setting/claude_test.go new file mode 100644 index 00000000..0a806a7a --- /dev/null +++ b/setting/model_setting/claude_test.go @@ -0,0 +1,60 @@ +package model_setting + +import ( + "net/http" + "testing" +) + +func TestClaudeSettingsWriteHeadersMergesConfiguredValuesIntoSingleHeader(t *testing.T) { + settings := &ClaudeSettings{ + HeadersSettings: map[string]map[string][]string{ + "claude-3-7-sonnet-20250219-thinking": { + "anthropic-beta": { + "token-efficient-tools-2025-02-19", + }, + }, + }, + } + + headers := http.Header{} + headers.Set("anthropic-beta", "output-128k-2025-02-19") + + settings.WriteHeaders("claude-3-7-sonnet-20250219-thinking", &headers) + + got := headers.Values("anthropic-beta") + if len(got) != 1 { + t.Fatalf("expected a single merged header value, got %v", got) + } + expected := "output-128k-2025-02-19,token-efficient-tools-2025-02-19" + if got[0] != expected { + t.Fatalf("expected merged header %q, got %q", expected, got[0]) + } +} + +func TestClaudeSettingsWriteHeadersDeduplicatesAcrossCommaSeparatedAndRepeatedValues(t *testing.T) { + settings := &ClaudeSettings{ + HeadersSettings: map[string]map[string][]string{ + "claude-3-7-sonnet-20250219-thinking": { + "anthropic-beta": { + "token-efficient-tools-2025-02-19", + "computer-use-2025-01-24", + }, + }, + }, + } + + headers := http.Header{} + headers.Add("anthropic-beta", "output-128k-2025-02-19, token-efficient-tools-2025-02-19") + headers.Add("anthropic-beta", "token-efficient-tools-2025-02-19") + + settings.WriteHeaders("claude-3-7-sonnet-20250219-thinking", &headers) + + got := headers.Values("anthropic-beta") + if len(got) != 1 { + t.Fatalf("expected duplicate values to collapse into one header, got %v", got) + } + expected := "output-128k-2025-02-19,token-efficient-tools-2025-02-19,computer-use-2025-01-24" + if got[0] != expected { + t.Fatalf("expected deduplicated merged header %q, got %q", expected, got[0]) + } +} diff --git a/web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx b/web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx index 61104d22..6f5021b4 100644 --- a/web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx +++ b/web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx @@ -377,6 +377,43 @@ function renderCompactDetailSummary(summarySegments) { ); } +function buildTieredBillingSegments(other, t) { + const segments = [ + { text: `${t('阶梯计费')}`, tone: 'primary' }, + ]; + + if (other.matched_tier) { + segments.push({ + text: `${t('命中档位')}: ${other.matched_tier}`, + tone: 'secondary', + }); + } + + const groupRatio = other.group_ratio; + if (groupRatio !== undefined && groupRatio !== null) { + segments.push({ + text: `${t('分组')} ${formatRatio(groupRatio)}x`, + tone: 'secondary', + }); + } + + if (other.crossed_tier) { + segments.push({ + text: `${t('跨阶梯')}: ${t('是')}`, + tone: 'secondary', + }); + } + + if (other.actual_quota_after_group !== undefined) { + segments.push({ + text: `${t('实际额度')}: ${other.actual_quota_after_group}`, + tone: 'secondary', + }); + } + + return { segments }; +} + function getUsageLogDetailSummary(record, text, billingDisplayMode, t) { const other = getLogOther(record.other); @@ -414,6 +451,10 @@ function getUsageLogDetailSummary(record, text, billingDisplayMode, t) { }; } + if (other?.billing_mode === 'tiered_expr') { + return buildTieredBillingSegments(other, t); + } + return { segments: other?.claude ? renderModelPriceSimple( diff --git a/web/src/hooks/usage-logs/useUsageLogsData.jsx b/web/src/hooks/usage-logs/useUsageLogsData.jsx index ebe1b882..84019a22 100644 --- a/web/src/hooks/usage-logs/useUsageLogsData.jsx +++ b/web/src/hooks/usage-logs/useUsageLogsData.jsx @@ -559,6 +559,66 @@ export const useLogsData = () => { value: other.reasoning_effort, }); } + if (other?.billing_mode === 'tiered_expr') { + expandDataLocal.push({ + key: t('计费方式'), + value: t('阶梯计费'), + }); + if (other?.group_ratio !== undefined) { + const gr = other.group_ratio; + expandDataLocal.push({ + key: t('分组倍率'), + value: typeof gr === 'number' ? gr.toFixed(4) : String(gr ?? '-'), + }); + } + if (other?.rule_version !== undefined) { + expandDataLocal.push({ + key: t('规则版本'), + value: String(other.rule_version), + }); + } + if (other?.estimated_env) { + expandDataLocal.push({ + key: t('预估环境'), + value: `prompt=${other.estimated_env.prompt_tokens ?? 0}, completion=${other.estimated_env.completion_tokens ?? 0}`, + }); + } + if (other?.actual_env) { + expandDataLocal.push({ + key: t('实际环境'), + value: `prompt=${other.actual_env.prompt_tokens ?? 0}, completion=${other.actual_env.completion_tokens ?? 0}`, + }); + } + if (other?.estimated_quota_after_group !== undefined) { + expandDataLocal.push({ + key: t('预估额度'), + value: String(other.estimated_quota_after_group), + }); + } + if (other?.actual_quota_after_group !== undefined) { + expandDataLocal.push({ + key: t('实际额度'), + value: String(other.actual_quota_after_group), + }); + } + expandDataLocal.push({ + key: t('跨阶梯'), + value: other?.crossed_tier ? t('是') : t('否'), + }); + if (Array.isArray(other?.breakdown) && other.breakdown.length > 0) { + const breakdownText = other.breakdown.map((item, idx) => + `[${idx}] ${item.token_type} | tokens=${item.tokens_in_tier} | cost=${item.unit_cost} | flat=${item.flat_fee} | sub=${item.subtotal}` + ).join('\n'); + expandDataLocal.push({ + key: t('计费明细'), + value: ( +
p ({t('输入 Token')}), c (
+ {t('输出 Token')}), cr ({t('缓存读取')}),{' '}
+ cc ({t('缓存创建')}),{' '}
+ cc1h ({t('缓存创建-1小时')})
+ tier(name, value),{' '}
+ max(a, b), min(a, b),{' '}
+ ceil(x), floor(x),{' '}
+ abs(x), header(name),{' '}
+ param(path), has(source, text)
+ prompt_tokens,{' '}
+ completion_tokens, cache_read_tokens,{' '}
+ cache_create_tokens,{' '}
+ cache_create_1h_tokens
+