From 91ed4e196a7bd92d9020aeadabc533178379186c Mon Sep 17 00:00:00 2001 From: CaIon Date: Mon, 16 Mar 2026 16:00:22 +0800 Subject: [PATCH] feat: implement tiered billing expression evaluation and related functionality - Added support for tiered billing expressions in the billing system. - Introduced new types and functions for handling billing expressions, including caching and execution. - Updated existing billing logic to accommodate tiered billing scenarios. - Enhanced request handling to support incoming billing expression requests. - Added tests for tiered billing functionality to ensure correctness. --- .gitignore | 3 + go.mod | 1 + go.sum | 2 + model/option.go | 7 + pkg/billingexpr/billingexpr_test.go | 933 +++++++++++ pkg/billingexpr/compile.go | 91 + pkg/billingexpr/round.go | 10 + pkg/billingexpr/run.go | 139 ++ pkg/billingexpr/settle.go | 24 + pkg/billingexpr/types.go | 59 + relay/audio_handler.go | 2 +- relay/chat_completions_via_responses.go | 5 +- relay/common/billing.go | 3 + relay/common/relay_info.go | 6 + relay/compatible_handler.go | 61 + relay/embedding_handler.go | 3 +- relay/helper/billing_expr_request.go | 54 + relay/helper/billing_expr_request_test.go | 35 + relay/helper/price.go | 75 + service/billing_session.go | 91 +- service/log_info_generate.go | 40 + service/quota.go | 79 + service/tiered_settle.go | 36 + service/tiered_settle_test.go | 401 +++++ setting/billing_setting/tiered_billing.go | 135 ++ setting/model_setting/claude_test.go | 60 + .../table/usage-logs/UsageLogsColumnDefs.jsx | 41 + web/src/hooks/usage-logs/useUsageLogsData.jsx | 60 + web/src/i18n/locales/en.json | 114 +- web/src/i18n/locales/zh-CN.json | 112 +- .../Ratio/components/ModelPricingEditor.jsx | 54 +- .../Ratio/components/TieredPricingEditor.jsx | 1488 +++++++++++++++++ .../Ratio/components/requestRuleExpr.js | 443 +++++ .../Ratio/hooks/useModelPricingEditorState.js | 156 +- 34 files changed, 4797 insertions(+), 26 deletions(-) create mode 100644 
pkg/billingexpr/billingexpr_test.go create mode 100644 pkg/billingexpr/compile.go create mode 100644 pkg/billingexpr/round.go create mode 100644 pkg/billingexpr/run.go create mode 100644 pkg/billingexpr/settle.go create mode 100644 pkg/billingexpr/types.go create mode 100644 relay/helper/billing_expr_request.go create mode 100644 relay/helper/billing_expr_request_test.go create mode 100644 service/tiered_settle.go create mode 100644 service/tiered_settle_test.go create mode 100644 setting/billing_setting/tiered_billing.go create mode 100644 setting/model_setting/claude_test.go create mode 100644 web/src/pages/Setting/Ratio/components/TieredPricingEditor.jsx create mode 100644 web/src/pages/Setting/Ratio/components/requestRuleExpr.js diff --git a/.gitignore b/.gitignore index 54fa8311..b5d65bcf 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,6 @@ data/ .gomodcache/ .gocache-temp .gopath + +token_estimator_test.go +skills-lock.json \ No newline at end of file diff --git a/go.mod b/go.mod index 2f28f781..43808976 100644 --- a/go.mod +++ b/go.mod @@ -75,6 +75,7 @@ require ( github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dlclark/regexp2 v1.11.5 // indirect github.com/dustin/go-humanize v1.0.1 // indirect + github.com/expr-lang/expr v1.17.8 github.com/fxamacker/cbor/v2 v2.9.0 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gin-contrib/sse v0.1.0 // indirect diff --git a/go.sum b/go.sum index 74298929..eb2be8a4 100644 --- a/go.sum +++ b/go.sum @@ -68,6 +68,8 @@ github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZ github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/expr-lang/expr v1.17.8 
h1:W1loDTT+0PQf5YteHSTpju2qfUfNoBt4yw9+wOEU9VM= +github.com/expr-lang/expr v1.17.8/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM= diff --git a/model/option.go b/model/option.go index 697e77df..b768addc 100644 --- a/model/option.go +++ b/model/option.go @@ -7,6 +7,7 @@ import ( "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/setting" + "github.com/QuantumNous/new-api/setting/billing_setting" "github.com/QuantumNous/new-api/setting/config" "github.com/QuantumNous/new-api/setting/operation_setting" "github.com/QuantumNous/new-api/setting/performance_setting" @@ -121,6 +122,8 @@ func InitOptionMap() { common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString() common.OptionMap["CompletionRatio"] = ratio_setting.CompletionRatio2JSONString() common.OptionMap["ImageRatio"] = ratio_setting.ImageRatio2JSONString() + common.OptionMap["ModelBillingMode"] = billing_setting.BillingMode2JSONString() + common.OptionMap["ModelBillingExpr"] = billing_setting.BillingExpr2JSONString() common.OptionMap["AudioRatio"] = ratio_setting.AudioRatio2JSONString() common.OptionMap["AudioCompletionRatio"] = ratio_setting.AudioCompletionRatio2JSONString() common.OptionMap["TopUpLink"] = common.TopUpLink @@ -436,6 +439,10 @@ func updateOptionMap(key string, value string) (err error) { err = ratio_setting.UpdateAudioRatioByJSONString(value) case "AudioCompletionRatio": err = ratio_setting.UpdateAudioCompletionRatioByJSONString(value) + case "ModelBillingMode": + err = billing_setting.UpdateBillingModeByJSONString(value) + case "ModelBillingExpr": + err = billing_setting.UpdateBillingExprByJSONString(value) case "TopUpLink": common.TopUpLink = value //case "ChatLink": diff --git 
a/pkg/billingexpr/billingexpr_test.go b/pkg/billingexpr/billingexpr_test.go new file mode 100644 index 00000000..34ec0fc0 --- /dev/null +++ b/pkg/billingexpr/billingexpr_test.go @@ -0,0 +1,933 @@ +package billingexpr_test + +import ( + "math" + "math/rand" + "testing" + + "github.com/QuantumNous/new-api/pkg/billingexpr" +) + +// --------------------------------------------------------------------------- +// Claude-style: fixed tiers, input > 200K changes both input & output price +// --------------------------------------------------------------------------- + +const claudeExpr = `p <= 200000 ? tier("standard", p * 1.5 + c * 7.5) : tier("long_context", p * 3.0 + c * 11.25)` + +func TestClaude_StandardTier(t *testing.T) { + cost, trace, err := billingexpr.RunExpr(claudeExpr, billingexpr.TokenParams{P: 100000, C: 5000}) + if err != nil { + t.Fatal(err) + } + want := 100000*1.5 + 5000*7.5 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f", cost, want) + } + if trace.MatchedTier != "standard" { + t.Errorf("tier = %q, want %q", trace.MatchedTier, "standard") + } +} + +func TestClaude_LongContextTier(t *testing.T) { + cost, trace, err := billingexpr.RunExpr(claudeExpr, billingexpr.TokenParams{P: 300000, C: 10000}) + if err != nil { + t.Fatal(err) + } + want := 300000*3.0 + 10000*11.25 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f", cost, want) + } + if trace.MatchedTier != "long_context" { + t.Errorf("tier = %q, want %q", trace.MatchedTier, "long_context") + } +} + +func TestClaude_BoundaryExact(t *testing.T) { + cost, trace, err := billingexpr.RunExpr(claudeExpr, billingexpr.TokenParams{P: 200000, C: 1000}) + if err != nil { + t.Fatal(err) + } + want := 200000*1.5 + 1000*7.5 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f", cost, want) + } + if trace.MatchedTier != "standard" { + t.Errorf("tier = %q, want %q", trace.MatchedTier, "standard") + } +} + +// 
--------------------------------------------------------------------------- +// GLM-style: multi-condition tiers with both input and output dimensions +// --------------------------------------------------------------------------- + +const glmExpr = ` +( + p < 32000 && c < 200 ? tier("tier1_short", (p)*2 + c*8) : + p < 32000 && c >= 200 ? tier("tier2_long_output", (p)*3 + c*14) : + tier("tier3_long_input", (p)*4 + c*16) +) / 1000000 +` + +func TestGLM_Tier1(t *testing.T) { + cost, trace, err := billingexpr.RunExpr(glmExpr, billingexpr.TokenParams{P: 15000, C: 100}) + if err != nil { + t.Fatal(err) + } + want := (15000.0*2 + 100.0*8) / 1000000 + if math.Abs(cost-want) > 1e-10 { + t.Errorf("cost = %f, want %f", cost, want) + } + if trace.MatchedTier != "tier1_short" { + t.Errorf("tier = %q, want %q", trace.MatchedTier, "tier1_short") + } +} + +func TestGLM_Tier2(t *testing.T) { + cost, trace, err := billingexpr.RunExpr(glmExpr, billingexpr.TokenParams{P: 15000, C: 500}) + if err != nil { + t.Fatal(err) + } + want := (15000.0*3 + 500.0*14) / 1000000 + if math.Abs(cost-want) > 1e-10 { + t.Errorf("cost = %f, want %f", cost, want) + } + if trace.MatchedTier != "tier2_long_output" { + t.Errorf("tier = %q, want %q", trace.MatchedTier, "tier2_long_output") + } +} + +func TestGLM_Tier3(t *testing.T) { + cost, trace, err := billingexpr.RunExpr(glmExpr, billingexpr.TokenParams{P: 50000, C: 100}) + if err != nil { + t.Fatal(err) + } + want := (50000.0*4 + 100.0*16) / 1000000 + if math.Abs(cost-want) > 1e-10 { + t.Errorf("cost = %f, want %f", cost, want) + } + if trace.MatchedTier != "tier3_long_input" { + t.Errorf("tier = %q, want %q", trace.MatchedTier, "tier3_long_input") + } +} + +// --------------------------------------------------------------------------- +// Simple flat-rate (no tier() call) +// --------------------------------------------------------------------------- + +func TestSimpleExpr_NoTier(t *testing.T) { + cost, trace, err := billingexpr.RunExpr("p * 0.5 + c * 
1.0", billingexpr.TokenParams{P: 1000, C: 500}) + if err != nil { + t.Fatal(err) + } + want := 1000*0.5 + 500*1.0 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f", cost, want) + } + if trace.MatchedTier != "" { + t.Errorf("tier should be empty, got %q", trace.MatchedTier) + } +} + +// --------------------------------------------------------------------------- +// Math helper functions +// --------------------------------------------------------------------------- + +func TestMathHelpers(t *testing.T) { + cost, _, err := billingexpr.RunExpr("max(p, c) * 0.5 + min(p, c) * 0.1", billingexpr.TokenParams{P: 300, C: 500}) + if err != nil { + t.Fatal(err) + } + want := 500*0.5 + 300*0.1 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f", cost, want) + } +} + +func TestRequestProbeHelpers(t *testing.T) { + cost, _, err := billingexpr.RunExprWithRequest( + `prompt_tokens * 0.5 + completion_tokens * 1.0 * (param("service_tier") == "fast" ? 2 : 1)`, + billingexpr.TokenParams{P: 1000, C: 500}, + billingexpr.RequestInput{ + Body: []byte(`{"service_tier":"fast"}`), + }, + ) + if err != nil { + t.Fatal(err) + } + want := 1000*0.5 + 500*1.0*2 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f", cost, want) + } +} + +func TestHeaderProbeHelper(t *testing.T) { + cost, _, err := billingexpr.RunExprWithRequest( + `p * 0.5 + c * 1.0 * (has(header("anthropic-beta"), "fast-mode") ? 2 : 1)`, + billingexpr.TokenParams{P: 1000, C: 500}, + billingexpr.RequestInput{ + Headers: map[string]string{ + "Anthropic-Beta": "fast-mode-2026-02-01", + }, + }, + ) + if err != nil { + t.Fatal(err) + } + want := 1000*0.5 + 500*1.0*2 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f", cost, want) + } +} + +func TestParamProbeNestedBool(t *testing.T) { + cost, _, err := billingexpr.RunExprWithRequest( + `p * (param("stream_options.fast_mode") == true ? 
1.5 : 1.0)`, + billingexpr.TokenParams{P: 100}, + billingexpr.RequestInput{ + Body: []byte(`{"stream_options":{"fast_mode":true}}`), + }, + ) + if err != nil { + t.Fatal(err) + } + want := 150.0 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f", cost, want) + } +} + +func TestParamProbeArrayLength(t *testing.T) { + cost, _, err := billingexpr.RunExprWithRequest( + `p * (param("messages.#") > 20 ? 1.2 : 1.0)`, + billingexpr.TokenParams{P: 100}, + billingexpr.RequestInput{ + Body: []byte(`{"messages":[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21]}`), + }, + ) + if err != nil { + t.Fatal(err) + } + want := 120.0 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f", cost, want) + } +} + +func TestRequestProbeMissingFieldReturnsNil(t *testing.T) { + cost, _, err := billingexpr.RunExprWithRequest( + `param("missing.value") == nil ? 2 : 1`, + billingexpr.TokenParams{}, + billingexpr.RequestInput{ + Body: []byte(`{"service_tier":"standard"}`), + }, + ) + if err != nil { + t.Fatal(err) + } + if cost != 2 { + t.Errorf("cost = %f, want 2", cost) + } +} + +func TestRequestProbeMultipleRulesMultiply(t *testing.T) { + cost, _, err := billingexpr.RunExprWithRequest( + `(param("service_tier") == "fast" ? 2 : 1) * (has(header("anthropic-beta"), "fast-mode-2026-02-01") ? 
2.5 : 1)`, + billingexpr.TokenParams{}, + billingexpr.RequestInput{ + Headers: map[string]string{ + "Anthropic-Beta": "fast-mode-2026-02-01", + }, + Body: []byte(`{"service_tier":"fast"}`), + }, + ) + if err != nil { + t.Fatal(err) + } + if math.Abs(cost-5) > 1e-6 { + t.Errorf("cost = %f, want 5", cost) + } +} + +func TestCeilFloor(t *testing.T) { + cost, _, err := billingexpr.RunExpr("ceil(p / 1000) * 0.5", billingexpr.TokenParams{P: 1500}) + if err != nil { + t.Fatal(err) + } + want := math.Ceil(1500.0/1000) * 0.5 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f", cost, want) + } +} + +// --------------------------------------------------------------------------- +// Zero tokens +// --------------------------------------------------------------------------- + +func TestZeroTokens(t *testing.T) { + cost, _, err := billingexpr.RunExpr(claudeExpr, billingexpr.TokenParams{}) + if err != nil { + t.Fatal(err) + } + if cost != 0 { + t.Errorf("cost should be 0 for zero tokens, got %f", cost) + } +} + +// --------------------------------------------------------------------------- +// Rounding +// --------------------------------------------------------------------------- + +func TestQuotaRound(t *testing.T) { + tests := []struct { + in float64 + want int + }{ + {0, 0}, + {0.4, 0}, + {0.5, 1}, + {0.6, 1}, + {1.5, 2}, + {-0.5, -1}, + {-0.6, -1}, + {999.4999, 999}, + {999.5, 1000}, + {1e9 + 0.5, 1e9 + 1}, + } + for _, tt := range tests { + got := billingexpr.QuotaRound(tt.in) + if got != tt.want { + t.Errorf("QuotaRound(%f) = %d, want %d", tt.in, got, tt.want) + } + } +} + +// --------------------------------------------------------------------------- +// Settlement +// --------------------------------------------------------------------------- + +func TestComputeTieredQuota_Basic(t *testing.T) { + snap := &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: claudeExpr, + ExprHash: billingexpr.ExprHashString(claudeExpr), + GroupRatio: 
1.0, + EstimatedPromptTokens: 100000, + EstimatedCompletionTokens: 5000, + EstimatedQuotaBeforeGroup: 100000*1.5 + 5000*7.5, + EstimatedQuotaAfterGroup: billingexpr.QuotaRound(100000*1.5 + 5000*7.5), + EstimatedTier: "standard", + } + + result, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{P: 300000, C: 10000}) + if err != nil { + t.Fatal(err) + } + + wantBefore := 300000*3.0 + 10000*11.25 + if math.Abs(result.ActualQuotaBeforeGroup-wantBefore) > 1e-6 { + t.Errorf("before group: got %f, want %f", result.ActualQuotaBeforeGroup, wantBefore) + } + if result.MatchedTier != "long_context" { + t.Errorf("tier = %q, want %q", result.MatchedTier, "long_context") + } + if !result.CrossedTier { + t.Error("expected crossed_tier=true (estimated standard, actual long_context)") + } +} + +func TestComputeTieredQuota_SameTier(t *testing.T) { + snap := &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: claudeExpr, + ExprHash: billingexpr.ExprHashString(claudeExpr), + GroupRatio: 1.5, + EstimatedPromptTokens: 50000, + EstimatedCompletionTokens: 1000, + EstimatedQuotaBeforeGroup: 50000*1.5 + 1000*7.5, + EstimatedQuotaAfterGroup: billingexpr.QuotaRound((50000*1.5 + 1000*7.5) * 1.5), + EstimatedTier: "standard", + } + + result, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{P: 80000, C: 2000}) + if err != nil { + t.Fatal(err) + } + + wantBefore := 80000*1.5 + 2000*7.5 + wantAfter := billingexpr.QuotaRound(wantBefore * 1.5) + if result.ActualQuotaAfterGroup != wantAfter { + t.Errorf("after group: got %d, want %d", result.ActualQuotaAfterGroup, wantAfter) + } + if result.CrossedTier { + t.Error("expected crossed_tier=false (both standard)") + } +} + +// --------------------------------------------------------------------------- +// Compile errors +// --------------------------------------------------------------------------- + +func TestCompileError(t *testing.T) { + _, _, err := billingexpr.RunExpr("invalid +-+ syntax", 
billingexpr.TokenParams{}) + if err == nil { + t.Error("expected compile error") + } +} + +// --------------------------------------------------------------------------- +// Compile Cache +// --------------------------------------------------------------------------- + +func TestCompileCache_SameResult(t *testing.T) { + r1, _, err := billingexpr.RunExpr("p * 0.5", billingexpr.TokenParams{P: 100}) + if err != nil { + t.Fatal(err) + } + r2, _, err := billingexpr.RunExpr("p * 0.5", billingexpr.TokenParams{P: 100}) + if err != nil { + t.Fatal(err) + } + if r1 != r2 { + t.Errorf("cached and uncached results differ: %f != %f", r1, r2) + } +} + +func TestInvalidateCache(t *testing.T) { + billingexpr.InvalidateCache() + r1, _, _ := billingexpr.RunExpr("p * 0.5", billingexpr.TokenParams{P: 100}) + billingexpr.InvalidateCache() + r2, _, _ := billingexpr.RunExpr("p * 0.5", billingexpr.TokenParams{P: 100}) + if r1 != r2 { + t.Errorf("post-invalidate results differ: %f != %f", r1, r2) + } +} + +// --------------------------------------------------------------------------- +// Hash +// --------------------------------------------------------------------------- + +func TestExprHashString_Deterministic(t *testing.T) { + h1 := billingexpr.ExprHashString("p * 0.5") + h2 := billingexpr.ExprHashString("p * 0.5") + if h1 != h2 { + t.Error("hash should be deterministic") + } + h3 := billingexpr.ExprHashString("p * 0.6") + if h1 == h3 { + t.Error("different expressions should have different hashes") + } +} + +// --------------------------------------------------------------------------- +// Cache variables: present +// --------------------------------------------------------------------------- + +const claudeWithCacheExpr = `p <= 200000 ? 
tier("standard", p * 1.5 + c * 7.5 + cr * 0.15 + cc * 1.875) : tier("long_context", p * 3.0 + c * 11.25 + cr * 0.3 + cc * 3.75)` + +func TestCachePresent_StandardTier(t *testing.T) { + params := billingexpr.TokenParams{P: 100000, C: 5000, CR: 50000, CC: 10000} + cost, trace, err := billingexpr.RunExpr(claudeWithCacheExpr, params) + if err != nil { + t.Fatal(err) + } + want := 100000*1.5 + 5000*7.5 + 50000*0.15 + 10000*1.875 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f", cost, want) + } + if trace.MatchedTier != "standard" { + t.Errorf("tier = %q, want %q", trace.MatchedTier, "standard") + } +} + +func TestCachePresent_LongContextTier(t *testing.T) { + params := billingexpr.TokenParams{P: 300000, C: 10000, CR: 100000, CC: 20000} + cost, trace, err := billingexpr.RunExpr(claudeWithCacheExpr, params) + if err != nil { + t.Fatal(err) + } + want := 300000*3.0 + 10000*11.25 + 100000*0.3 + 20000*3.75 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f", cost, want) + } + if trace.MatchedTier != "long_context" { + t.Errorf("tier = %q, want %q", trace.MatchedTier, "long_context") + } +} + +// --------------------------------------------------------------------------- +// Cache variables: absent (all zero) — same expression still works +// --------------------------------------------------------------------------- + +func TestCacheAbsent_ZeroCacheTokens(t *testing.T) { + params := billingexpr.TokenParams{P: 100000, C: 5000} + cost, trace, err := billingexpr.RunExpr(claudeWithCacheExpr, params) + if err != nil { + t.Fatal(err) + } + want := 100000*1.5 + 5000*7.5 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f (cache terms should be 0)", cost, want) + } + if trace.MatchedTier != "standard" { + t.Errorf("tier = %q, want %q", trace.MatchedTier, "standard") + } +} + +// --------------------------------------------------------------------------- +// Mixed cache fields: cc and cc1h non-zero +// 
--------------------------------------------------------------------------- + +const claudeCacheSplitExpr = `tier("default", p * 1.5 + c * 7.5 + cr * 0.15 + cc * 2.0 + cc1h * 3.0)` + +func TestMixedCacheFields(t *testing.T) { + params := billingexpr.TokenParams{P: 100000, C: 5000, CR: 10000, CC: 5000, CC1h: 2000} + cost, _, err := billingexpr.RunExpr(claudeCacheSplitExpr, params) + if err != nil { + t.Fatal(err) + } + want := 100000*1.5 + 5000*7.5 + 10000*0.15 + 5000*2.0 + 2000*3.0 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f", cost, want) + } +} + +func TestMixedCacheFields_AllCacheZero(t *testing.T) { + params := billingexpr.TokenParams{P: 100000, C: 5000} + cost, _, err := billingexpr.RunExpr(claudeCacheSplitExpr, params) + if err != nil { + t.Fatal(err) + } + want := 100000*1.5 + 5000*7.5 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f (all cache zero)", cost, want) + } +} + +// --------------------------------------------------------------------------- +// Backward compatibility: p+c only expressions still work with TokenParams +// --------------------------------------------------------------------------- + +func TestBackwardCompat_OldExprWithTokenParams(t *testing.T) { + params := billingexpr.TokenParams{P: 100000, C: 5000, CR: 99999, CC: 88888} + cost, trace, err := billingexpr.RunExpr(claudeExpr, params) + if err != nil { + t.Fatal(err) + } + want := 100000*1.5 + 5000*7.5 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f (old expr ignores cache fields)", cost, want) + } + if trace.MatchedTier != "standard" { + t.Errorf("tier = %q, want %q", trace.MatchedTier, "standard") + } +} + +// --------------------------------------------------------------------------- +// Settlement with cache tokens +// --------------------------------------------------------------------------- + +func TestComputeTieredQuota_WithCache(t *testing.T) { + snap := &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + 
ExprString: claudeWithCacheExpr, + ExprHash: billingexpr.ExprHashString(claudeWithCacheExpr), + GroupRatio: 1.0, + EstimatedPromptTokens: 100000, + EstimatedCompletionTokens: 5000, + EstimatedQuotaBeforeGroup: 100000*1.5 + 5000*7.5, + EstimatedQuotaAfterGroup: billingexpr.QuotaRound(100000*1.5 + 5000*7.5), + EstimatedTier: "standard", + } + + params := billingexpr.TokenParams{P: 100000, C: 5000, CR: 50000, CC: 10000} + result, err := billingexpr.ComputeTieredQuota(snap, params) + if err != nil { + t.Fatal(err) + } + + wantBefore := 100000*1.5 + 5000*7.5 + 50000*0.15 + 10000*1.875 + if math.Abs(result.ActualQuotaBeforeGroup-wantBefore) > 1e-6 { + t.Errorf("before group: got %f, want %f", result.ActualQuotaBeforeGroup, wantBefore) + } + if result.MatchedTier != "standard" { + t.Errorf("tier = %q, want %q", result.MatchedTier, "standard") + } + if result.CrossedTier { + t.Error("expected crossed_tier=false (same tier)") + } +} + +func TestComputeTieredQuota_WithCacheCrossTier(t *testing.T) { + snap := &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: claudeWithCacheExpr, + ExprHash: billingexpr.ExprHashString(claudeWithCacheExpr), + GroupRatio: 2.0, + EstimatedPromptTokens: 100000, + EstimatedCompletionTokens: 5000, + EstimatedQuotaBeforeGroup: 100000*1.5 + 5000*7.5, + EstimatedQuotaAfterGroup: billingexpr.QuotaRound((100000*1.5 + 5000*7.5) * 2.0), + EstimatedTier: "standard", + } + + params := billingexpr.TokenParams{P: 300000, C: 10000, CR: 50000, CC: 10000} + result, err := billingexpr.ComputeTieredQuota(snap, params) + if err != nil { + t.Fatal(err) + } + + wantBefore := 300000*3.0 + 10000*11.25 + 50000*0.3 + 10000*3.75 + wantAfter := billingexpr.QuotaRound(wantBefore * 2.0) + if math.Abs(result.ActualQuotaBeforeGroup-wantBefore) > 1e-6 { + t.Errorf("before group: got %f, want %f", result.ActualQuotaBeforeGroup, wantBefore) + } + if result.ActualQuotaAfterGroup != wantAfter { + t.Errorf("after group: got %d, want %d", 
result.ActualQuotaAfterGroup, wantAfter) + } + if !result.CrossedTier { + t.Error("expected crossed_tier=true (estimated standard, actual long_context)") + } +} + +// --------------------------------------------------------------------------- +// Fuzz: random p/c/cache, verify non-negative result +// --------------------------------------------------------------------------- + +func TestFuzz_NonNegativeResults(t *testing.T) { + exprs := []string{ + claudeExpr, + claudeWithCacheExpr, + glmExpr, + "p * 0.5 + c * 1.0", + "max(p, c) * 0.1", + "p * 0.5 + cr * 0.1 + cc * 0.2", + } + + rng := rand.New(rand.NewSource(42)) + + for _, exprStr := range exprs { + for i := 0; i < 500; i++ { + params := billingexpr.TokenParams{ + P: math.Round(rng.Float64() * 1000000), + C: math.Round(rng.Float64() * 500000), + CR: math.Round(rng.Float64() * 200000), + CC: math.Round(rng.Float64() * 50000), + CC1h: math.Round(rng.Float64() * 10000), + } + + cost, _, err := billingexpr.RunExpr(exprStr, params) + if err != nil { + t.Fatalf("expr=%q params=%+v: %v", exprStr, params, err) + } + if cost < 0 { + t.Errorf("expr=%q params=%+v: negative cost %f", exprStr, params, cost) + } + } + } +} + +func TestFuzz_SettlementConsistency(t *testing.T) { + rng := rand.New(rand.NewSource(99)) + + for i := 0; i < 200; i++ { + estParams := billingexpr.TokenParams{ + P: math.Round(rng.Float64() * 500000), + C: math.Round(rng.Float64() * 100000), + CR: math.Round(rng.Float64() * 100000), + CC: math.Round(rng.Float64() * 30000), + } + actParams := billingexpr.TokenParams{ + P: math.Round(rng.Float64() * 500000), + C: math.Round(rng.Float64() * 100000), + CR: math.Round(rng.Float64() * 100000), + CC: math.Round(rng.Float64() * 30000), + } + groupRatio := 0.5 + rng.Float64()*2.0 + + estCost, estTrace, _ := billingexpr.RunExpr(claudeWithCacheExpr, estParams) + + snap := &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: claudeWithCacheExpr, + ExprHash: 
billingexpr.ExprHashString(claudeWithCacheExpr), + GroupRatio: groupRatio, + EstimatedPromptTokens: int(estParams.P), + EstimatedCompletionTokens: int(estParams.C), + EstimatedQuotaBeforeGroup: estCost, + EstimatedQuotaAfterGroup: billingexpr.QuotaRound(estCost * groupRatio), + EstimatedTier: estTrace.MatchedTier, + } + + result, err := billingexpr.ComputeTieredQuota(snap, actParams) + if err != nil { + t.Fatalf("iter %d: %v", i, err) + } + + directCost, _, _ := billingexpr.RunExpr(claudeWithCacheExpr, actParams) + directQuota := billingexpr.QuotaRound(directCost * groupRatio) + + if result.ActualQuotaAfterGroup != directQuota { + t.Errorf("iter %d: settlement %d != direct %d", i, result.ActualQuotaAfterGroup, directQuota) + } + } +} + +// --------------------------------------------------------------------------- +// Settlement-level tests for ComputeTieredQuota +// --------------------------------------------------------------------------- + +func TestComputeTieredQuota_BasicSettlement(t *testing.T) { + exprStr := `tier("default", p + c)` + snap := &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: exprStr, + ExprHash: billingexpr.ExprHashString(exprStr), + GroupRatio: 1.0, + } + + result, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{P: 3000, C: 2000}) + if err != nil { + t.Fatal(err) + } + if math.Abs(result.ActualQuotaBeforeGroup-5000) > 1e-6 { + t.Errorf("before group = %f, want 5000", result.ActualQuotaBeforeGroup) + } + if result.ActualQuotaAfterGroup != 5000 { + t.Errorf("after group = %d, want 5000", result.ActualQuotaAfterGroup) + } + if result.MatchedTier != "default" { + t.Errorf("tier = %q, want default", result.MatchedTier) + } +} + +func TestComputeTieredQuota_WithGroupRatio(t *testing.T) { + exprStr := `tier("default", p + c)` + snap := &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: exprStr, + ExprHash: billingexpr.ExprHashString(exprStr), + GroupRatio: 2.0, + } + + result, err 
:= billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{P: 1000, C: 500}) + if err != nil { + t.Fatal(err) + } + // cost = 1500, after group = round(1500 * 2.0) = 3000 + if result.ActualQuotaAfterGroup != 3000 { + t.Errorf("after group = %d, want 3000", result.ActualQuotaAfterGroup) + } +} + +func TestComputeTieredQuota_ZeroTokens(t *testing.T) { + exprStr := `tier("default", p * 2 + c * 10)` + snap := &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: exprStr, + ExprHash: billingexpr.ExprHashString(exprStr), + GroupRatio: 1.0, + } + + result, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{}) + if err != nil { + t.Fatal(err) + } + if result.ActualQuotaAfterGroup != 0 { + t.Errorf("after group = %d, want 0", result.ActualQuotaAfterGroup) + } +} + +func TestComputeTieredQuota_RoundingEdge(t *testing.T) { + exprStr := `tier("default", p * 0.5)` // 3 * 0.5 = 1.5 -> round to 2 + snap := &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: exprStr, + ExprHash: billingexpr.ExprHashString(exprStr), + GroupRatio: 1.0, + } + + result, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{P: 3}) + if err != nil { + t.Fatal(err) + } + // 3 * 0.5 = 1.5, round(1.5) = 2 + if result.ActualQuotaAfterGroup != 2 { + t.Errorf("after group = %d, want 2 (round 1.5 up)", result.ActualQuotaAfterGroup) + } +} + +func TestComputeTieredQuota_RoundingEdgeDown(t *testing.T) { + exprStr := `tier("default", p * 0.4)` // 3 * 0.4 = 1.2 -> round to 1 + snap := &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: exprStr, + ExprHash: billingexpr.ExprHashString(exprStr), + GroupRatio: 1.0, + } + + result, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{P: 3}) + if err != nil { + t.Fatal(err) + } + // 3 * 0.4 = 1.2, round(1.2) = 1 + if result.ActualQuotaAfterGroup != 1 { + t.Errorf("after group = %d, want 1 (round 1.2 down)", result.ActualQuotaAfterGroup) + } +} + +func 
TestComputeTieredQuotaWithRequest_ProbeAffectsQuota(t *testing.T) { + exprStr := `param("fast") == true ? tier("fast", p * 4) : tier("normal", p * 2)` + snap := &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: exprStr, + ExprHash: billingexpr.ExprHashString(exprStr), + GroupRatio: 1.0, + EstimatedTier: "normal", + } + + // Without request: normal tier + r1, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{P: 1000}) + if err != nil { + t.Fatal(err) + } + if r1.ActualQuotaAfterGroup != 2000 { + t.Errorf("normal = %d, want 2000", r1.ActualQuotaAfterGroup) + } + + // With request: fast tier + r2, err := billingexpr.ComputeTieredQuotaWithRequest(snap, billingexpr.TokenParams{P: 1000}, billingexpr.RequestInput{ + Body: []byte(`{"fast":true}`), + }) + if err != nil { + t.Fatal(err) + } + if r2.ActualQuotaAfterGroup != 4000 { + t.Errorf("fast = %d, want 4000", r2.ActualQuotaAfterGroup) + } + if !r2.CrossedTier { + t.Error("expected CrossedTier = true when probe changes tier") + } +} + +func TestComputeTieredQuota_BoundaryTierCrossing(t *testing.T) { + exprStr := `p <= 100000 ? 
tier("small", p * 1) : tier("large", p * 2)` + snap := &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: exprStr, + ExprHash: billingexpr.ExprHashString(exprStr), + GroupRatio: 1.0, + EstimatedTier: "small", + } + + // At boundary + r1, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{P: 100000}) + if err != nil { + t.Fatal(err) + } + if r1.MatchedTier != "small" { + t.Errorf("at boundary: tier = %s, want small", r1.MatchedTier) + } + if r1.ActualQuotaAfterGroup != 100000 { + t.Errorf("at boundary: quota = %d, want 100000", r1.ActualQuotaAfterGroup) + } + + // Past boundary + r2, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{P: 100001}) + if err != nil { + t.Fatal(err) + } + if r2.MatchedTier != "large" { + t.Errorf("past boundary: tier = %s, want large", r2.MatchedTier) + } + if r2.ActualQuotaAfterGroup != 200002 { + t.Errorf("past boundary: quota = %d, want 200002", r2.ActualQuotaAfterGroup) + } + if !r2.CrossedTier { + t.Error("expected CrossedTier = true") + } +} + +// --------------------------------------------------------------------------- +// Time function tests +// --------------------------------------------------------------------------- + +func TestTimeFunctions_ValidTimezone(t *testing.T) { + exprStr := `tier("default", p) * (hour("UTC") >= 0 ? 1 : 1)` + cost, _, err := billingexpr.RunExpr(exprStr, billingexpr.TokenParams{P: 100}) + if err != nil { + t.Fatal(err) + } + if cost != 100 { + t.Errorf("cost = %f, want 100", cost) + } +} + +func TestTimeFunctions_AllFunctionsCompile(t *testing.T) { + exprStr := `tier("default", p) * (hour("Asia/Shanghai") >= 0 ? 1 : 1) * (minute("UTC") >= 0 ? 1 : 1) * (weekday("UTC") >= 0 ? 1 : 1) * (month("UTC") >= 1 ? 1 : 1) * (day("UTC") >= 1 ? 
1 : 1)` + cost, _, err := billingexpr.RunExpr(exprStr, billingexpr.TokenParams{P: 500}) + if err != nil { + t.Fatal(err) + } + if cost != 500 { + t.Errorf("cost = %f, want 500", cost) + } +} + +func TestTimeFunctions_InvalidTimezone(t *testing.T) { + exprStr := `tier("default", p) * (hour("Invalid/Zone") >= 0 ? 1 : 2)` + cost, _, err := billingexpr.RunExpr(exprStr, billingexpr.TokenParams{P: 100}) + if err != nil { + t.Fatal(err) + } + // Invalid timezone falls back to UTC; hour is 0-23, so condition is always true + if cost != 100 { + t.Errorf("cost = %f, want 100 (fallback to UTC)", cost) + } +} + +func TestTimeFunctions_EmptyTimezone(t *testing.T) { + exprStr := `tier("default", p) * (hour("") >= 0 ? 1 : 2)` + cost, _, err := billingexpr.RunExpr(exprStr, billingexpr.TokenParams{P: 100}) + if err != nil { + t.Fatal(err) + } + if cost != 100 { + t.Errorf("cost = %f, want 100 (empty tz -> UTC)", cost) + } +} + +func TestTimeFunctions_NightDiscountPattern(t *testing.T) { + exprStr := `tier("default", p * 2 + c * 10) * (hour("UTC") >= 21 || hour("UTC") < 6 ? 0.5 : 1)` + cost, _, err := billingexpr.RunExpr(exprStr, billingexpr.TokenParams{P: 1000, C: 500}) + if err != nil { + t.Fatal(err) + } + // Base = 1000*2 + 500*10 = 7000; multiplier is either 0.5 or 1 depending on current UTC hour + if cost != 7000 && cost != 3500 { + t.Errorf("cost = %f, want 7000 or 3500", cost) + } +} + +func TestTimeFunctions_WeekdayRange(t *testing.T) { + exprStr := `tier("default", p) * (weekday("UTC") >= 0 && weekday("UTC") <= 6 ? 1 : 999)` + cost, _, err := billingexpr.RunExpr(exprStr, billingexpr.TokenParams{P: 100}) + if err != nil { + t.Fatal(err) + } + // weekday is always 0-6, so multiplier is always 1 + if cost != 100 { + t.Errorf("cost = %f, want 100", cost) + } +} + +func TestTimeFunctions_MonthDayPattern(t *testing.T) { + exprStr := `tier("default", p) * (month("Asia/Shanghai") == 1 && day("Asia/Shanghai") == 1 ? 
0.5 : 1)` + cost, _, err := billingexpr.RunExpr(exprStr, billingexpr.TokenParams{P: 1000}) + if err != nil { + t.Fatal(err) + } + // Either 1000 (not Jan 1) or 500 (Jan 1) — both are valid + if cost != 1000 && cost != 500 { + t.Errorf("cost = %f, want 1000 or 500", cost) + } +} diff --git a/pkg/billingexpr/compile.go b/pkg/billingexpr/compile.go new file mode 100644 index 00000000..4a8b61e9 --- /dev/null +++ b/pkg/billingexpr/compile.go @@ -0,0 +1,91 @@ +package billingexpr + +import ( + "fmt" + "math" + "sync" + + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" +) + +const maxCacheSize = 256 + +var ( + cacheMu sync.RWMutex + cache = make(map[string]*vm.Program, 64) +) + +// compileEnvPrototype is the type-checking prototype used at compile time. +// It declares the shape of the environment that RunExpr will provide. +// The tier() function is a no-op placeholder here; the real one with +// side-channel tracing is injected at runtime. +var compileEnvPrototype = map[string]interface{}{ + "p": float64(0), + "c": float64(0), + "cr": float64(0), + "cc": float64(0), + "cc1h": float64(0), + "prompt_tokens": float64(0), + "completion_tokens": float64(0), + "cache_read_tokens": float64(0), + "cache_create_tokens": float64(0), + "cache_create_1h_tokens": float64(0), + "tier": func(string, float64) float64 { return 0 }, + "header": func(string) string { return "" }, + "param": func(string) interface{} { return nil }, + "has": func(interface{}, string) bool { return false }, + "hour": func(string) int { return 0 }, + "minute": func(string) int { return 0 }, + "weekday": func(string) int { return 0 }, + "month": func(string) int { return 0 }, + "day": func(string) int { return 0 }, + "max": math.Max, + "min": math.Min, + "abs": math.Abs, + "ceil": math.Ceil, + "floor": math.Floor, +} + +// CompileFromCache compiles an expression string, using a cached program when +// available. The cache is keyed by the SHA-256 hex digest of the expression. 
+func CompileFromCache(exprStr string) (*vm.Program, error) {
+	return compileFromCacheByHash(exprStr, ExprHashString(exprStr))
+}
+
+// CompileFromCacheByHash is like CompileFromCache but accepts a pre-computed
+// hash, useful when the caller already has the BillingSnapshot.ExprHash.
+// An empty hash is treated as "not provided" and is derived from exprStr.
+func CompileFromCacheByHash(exprStr, hash string) (*vm.Program, error) {
+	return compileFromCacheByHash(exprStr, hash)
+}
+
+// compileFromCacheByHash returns the cached program stored under hash, or
+// compiles exprStr against compileEnvPrototype (result coerced to float64)
+// and stores it. When the cache already holds maxCacheSize entries it is
+// dropped wholesale before the new program is inserted.
+func compileFromCacheByHash(exprStr, hash string) (*vm.Program, error) {
+	// A caller-supplied empty hash (e.g. a snapshot that never had ExprHash
+	// filled in) would make every expression share the "" cache slot and get
+	// back whichever program was compiled first. Derive the real key instead.
+	if hash == "" {
+		hash = ExprHashString(exprStr)
+	}
+
+	cacheMu.RLock()
+	if prog, ok := cache[hash]; ok {
+		cacheMu.RUnlock()
+		return prog, nil
+	}
+	cacheMu.RUnlock()
+
+	// Compile outside the lock; concurrent callers may compile the same
+	// expression twice, but the last writer simply overwrites the entry.
+	prog, err := expr.Compile(exprStr, expr.Env(compileEnvPrototype), expr.AsFloat64())
+	if err != nil {
+		return nil, fmt.Errorf("expr compile error: %w", err)
+	}
+
+	cacheMu.Lock()
+	if len(cache) >= maxCacheSize {
+		cache = make(map[string]*vm.Program, 64)
+	}
+	cache[hash] = prog
+	cacheMu.Unlock()
+
+	return prog, nil
+}
+
+// InvalidateCache clears the compiled-expression cache.
+// Called when billing rules are updated.
+func InvalidateCache() {
+	cacheMu.Lock()
+	cache = make(map[string]*vm.Program, 64)
+	cacheMu.Unlock()
+}
diff --git a/pkg/billingexpr/round.go b/pkg/billingexpr/round.go
new file mode 100644
index 00000000..35a5534a
--- /dev/null
+++ b/pkg/billingexpr/round.go
@@ -0,0 +1,10 @@
+package billingexpr
+
+import "math"
+
+// QuotaRound converts a float64 quota value to int using half-away-from-zero
+// rounding. Every tiered billing path (pre-consume, settlement, breakdown
+// validation, log fields) MUST use this function to avoid +-1 discrepancies.
+func QuotaRound(f float64) int {
+	// Converting NaN, ±Inf or an out-of-range float to int is
+	// implementation-dependent in Go, and a billing expression can produce
+	// such values (e.g. division by zero). Map NaN to 0 and saturate at the
+	// int range so the charged quota is always well defined.
+	if math.IsNaN(f) {
+		return 0
+	}
+	r := math.Round(f)
+	switch {
+	case r >= float64(math.MaxInt):
+		return math.MaxInt
+	case r <= float64(math.MinInt):
+		return math.MinInt
+	}
+	return int(r)
+}
diff --git a/pkg/billingexpr/run.go b/pkg/billingexpr/run.go
new file mode 100644
index 00000000..267c5af2
--- /dev/null
+++ b/pkg/billingexpr/run.go
@@ -0,0 +1,139 @@
+package billingexpr
+
+import (
+	"fmt"
+	"math"
+	"strings"
+	"time"
+
+	"github.com/expr-lang/expr"
+	"github.com/expr-lang/expr/vm"
+	"github.com/tidwall/gjson"
+)
+
+// RunExpr compiles (with cache) and executes an expression string.
+// The environment exposes:
+//   - p, c — prompt / completion tokens
+//   - cr, cc, cc1h — cache read / creation / creation-1h tokens
+//   - prompt_tokens, completion_tokens, cache_read_tokens,
+//     cache_create_tokens, cache_create_1h_tokens — long-form aliases
+//   - tier(name, value) — trace callback that records which tier matched
+//   - header(name), param(path), has(source, substr) — request probes
+//   - hour(tz), minute(tz), weekday(tz), month(tz), day(tz) — current time
+//   - max, min, abs, ceil, floor — standard math helpers
+//
+// RunExpr evaluates with an empty RequestInput, so the request probes see no
+// headers and no body.
+//
+// Returns the resulting float64 quota (before group ratio) and a TraceResult
+// with side-channel info captured by tier() during execution.
+func RunExpr(exprStr string, params TokenParams) (float64, TraceResult, error) {
+	return RunExprWithRequest(exprStr, params, RequestInput{})
+}
+
+// RunExprWithRequest is like RunExpr but also makes the given request
+// headers and body available to header() and param().
+func RunExprWithRequest(exprStr string, params TokenParams, request RequestInput) (float64, TraceResult, error) {
+	prog, err := CompileFromCache(exprStr)
+	if err != nil {
+		return 0, TraceResult{}, err
+	}
+	return runProgram(prog, params, request)
+}
+
+// RunExprByHash is like RunExpr but accepts a pre-computed hash for the cache
+// lookup, avoiding a redundant SHA-256 computation when the caller already
+// holds BillingSnapshot.ExprHash.
+func RunExprByHash(exprStr, hash string, params TokenParams) (float64, TraceResult, error) {
+	return RunExprByHashWithRequest(exprStr, hash, params, RequestInput{})
+}
+
+// RunExprByHashWithRequest is like RunExprByHash but also makes the given
+// request headers and body available to header() and param().
+func RunExprByHashWithRequest(exprStr, hash string, params TokenParams, request RequestInput) (float64, TraceResult, error) {
+	prog, err := CompileFromCacheByHash(exprStr, hash)
+	if err != nil {
+		return 0, TraceResult{}, err
+	}
+	return runProgram(prog, params, request)
+}
+
+// runProgram executes a compiled expression against the token counts and
+// request input. It returns the float64 result together with the trace
+// written by the last tier() call made during evaluation.
+func runProgram(prog *vm.Program, params TokenParams, request RequestInput) (float64, TraceResult, error) {
+	trace := TraceResult{}
+	headers := normalizeHeaders(request.Headers)
+
+	// Sample the clock once so every time helper in a single evaluation sees
+	// the same instant. Otherwise an expression such as
+	// month(tz) == 1 && day(tz) == 1 could read two different moments when it
+	// runs across a day or month boundary.
+	now := time.Now()
+	inZone := func(tz string) time.Time { return timeInZoneAt(now, tz) }
+
+	env := map[string]interface{}{
+		"p":                      params.P,
+		"c":                      params.C,
+		"cr":                     params.CR,
+		"cc":                     params.CC,
+		"cc1h":                   params.CC1h,
+		"prompt_tokens":          params.P,
+		"completion_tokens":      params.C,
+		"cache_read_tokens":      params.CR,
+		"cache_create_tokens":    params.CC,
+		"cache_create_1h_tokens": params.CC1h,
+		// tier returns value unchanged and records the name/value pair; the
+		// last call made during evaluation wins.
+		"tier": func(name string, value float64) float64 {
+			trace.MatchedTier = name
+			trace.Cost = value
+			return value
+		},
+		// header looks the key up case-insensitively, ignoring surrounding
+		// whitespace; missing headers yield "".
+		"header": func(key string) string {
+			return headers[strings.ToLower(strings.TrimSpace(key))]
+		},
+		// param resolves a gjson path against the request body and returns
+		// nil when the path is blank, the body is empty, or nothing matches.
+		"param": func(path string) interface{} {
+			path = strings.TrimSpace(path)
+			if path == "" || len(request.Body) == 0 {
+				return nil
+			}
+			result := gjson.GetBytes(request.Body, path)
+			if !result.Exists() {
+				return nil
+			}
+			return result.Value()
+		},
+		// has reports whether the fmt.Sprint rendering of source contains
+		// substr; a nil source or empty substr is never a match.
+		"has": func(source interface{}, substr string) bool {
+			if source == nil || substr == "" {
+				return false
+			}
+			return strings.Contains(fmt.Sprint(source), substr)
+		},
+		"hour":    func(tz string) int { return inZone(tz).Hour() },
+		"minute":  func(tz string) int { return inZone(tz).Minute() },
+		"weekday": func(tz string) int { return int(inZone(tz).Weekday()) },
+		"month":   func(tz string) int { return int(inZone(tz).Month()) },
+		"day":     func(tz string) int { return inZone(tz).Day() },
+		"max":     math.Max,
+		"min":     math.Min,
+		"abs":     math.Abs,
+		"ceil":    math.Ceil,
+		"floor":   math.Floor,
+	}
+
+	out, err := expr.Run(prog, env)
+	if err != nil {
+		return 0, trace, fmt.Errorf("expr run error: %w", err)
+	}
+	f, ok := out.(float64)
+	if !ok {
+		return 0, trace, fmt.Errorf("expr result is %T, want float64", out)
+	}
+	return f, trace, nil
+}
+
+// timeInZone returns the current time in the named IANA zone. A blank or
+// unknown zone name falls back to UTC.
+func timeInZone(tz string) time.Time {
+	return timeInZoneAt(time.Now(), tz)
+}
+
+// timeInZoneAt converts now into the named IANA zone. A blank or unknown
+// zone name falls back to UTC.
+func timeInZoneAt(now time.Time, tz string) time.Time {
+	tz = strings.TrimSpace(tz)
+	if tz == "" {
+		return now.UTC()
+	}
+	loc, err := time.LoadLocation(tz)
+	if err != nil {
+		return now.UTC()
+	}
+	return now.In(loc)
+}
+
+// normalizeHeaders returns a copy of headers with keys lower-cased and both
+// keys and values trimmed. Entries whose key or value is empty after
+// trimming are dropped. The result is never nil.
+func normalizeHeaders(headers map[string]string) map[string]string {
+	if len(headers) == 0 {
+		return map[string]string{}
+	}
+	normalized := make(map[string]string, len(headers))
+	for key, value := range headers {
+		k := strings.ToLower(strings.TrimSpace(key))
+		v := strings.TrimSpace(value)
+		if k == "" || v == "" {
+			continue
+		}
+		normalized[k] = v
+	}
+	return normalized
+}
diff --git a/pkg/billingexpr/settle.go b/pkg/billingexpr/settle.go
new file mode 100644
index 00000000..7e69b9e3
--- /dev/null
+++ b/pkg/billingexpr/settle.go
@@ -0,0 +1,24 @@
+package billingexpr
+
+// ComputeTieredQuota runs the Expr from a frozen BillingSnapshot against
+// actual token counts and returns the settlement result.
+func ComputeTieredQuota(snap *BillingSnapshot, params TokenParams) (TieredResult, error) {
+	return ComputeTieredQuotaWithRequest(snap, params, RequestInput{})
+}
+
+// ComputeTieredQuotaWithRequest is like ComputeTieredQuota but also exposes
+// the request headers and body to the expression. The group ratio stored in
+// the snapshot is applied to the expression result and rounded with
+// QuotaRound; CrossedTier is set when the tier matched now differs from the
+// snapshot's EstimatedTier. A nil snapshot yields an error instead of a
+// nil-pointer panic.
+func ComputeTieredQuotaWithRequest(snap *BillingSnapshot, params TokenParams, request RequestInput) (TieredResult, error) {
+	if snap == nil {
+		return TieredResult{}, nilSnapshotError{}
+	}
+
+	cost, trace, err := RunExprByHashWithRequest(snap.ExprString, snap.ExprHash, params, request)
+	if err != nil {
+		return TieredResult{}, err
+	}
+
+	afterGroup := QuotaRound(cost * snap.GroupRatio)
+	crossed := trace.MatchedTier != snap.EstimatedTier
+
+	return TieredResult{
+		ActualQuotaBeforeGroup: cost,
+		ActualQuotaAfterGroup:  afterGroup,
+		MatchedTier:            trace.MatchedTier,
+		CrossedTier:            crossed,
+	}, nil
+}
+
+// nilSnapshotError reports a settlement attempt made without a snapshot.
+// It is a small local type so this file does not need any imports.
+type nilSnapshotError struct{}
+
+func (nilSnapshotError) Error() string { return "billingexpr: nil billing snapshot" }
diff --git a/pkg/billingexpr/types.go b/pkg/billingexpr/types.go
new file mode 100644
index 00000000..193f82b4
--- /dev/null
+++ b/pkg/billingexpr/types.go
@@ -0,0 +1,59 @@
+package billingexpr
+
+import (
+	"crypto/sha256"
+	"fmt"
+)
+
+// RequestInput carries the request data an expression may probe: Headers
+// backs header() and Body backs param().
+type RequestInput struct {
+	Headers map[string]string
+	Body    []byte
+}
+
+// TokenParams holds all token dimensions passed into an Expr evaluation.
+// Fields beyond P and C are optional — when absent they default to 0,
+// which means cache-unaware expressions keep working unchanged.
+type TokenParams struct {
+	P    float64 // prompt tokens
+	C    float64 // completion tokens
+	CR   float64 // cache read (hit) tokens
+	CC   float64 // cache creation tokens (5-min TTL for Claude, generic for others)
+	CC1h float64 // cache creation tokens — 1-hour TTL (Claude only)
+}
+
+// TraceResult holds side-channel info captured by the tier() function
+// during Expr execution. This replaces the old Breakdown mechanism —
+// the Expr itself is the single source of truth for billing logic.
+type TraceResult struct {
+	MatchedTier string  `json:"matched_tier"`
+	Cost        float64 `json:"cost"`
+}
+
+// BillingSnapshot captures the billing rule state frozen at pre-consume time.
+// It is fully serializable and contains no compiled program pointers.
+type BillingSnapshot struct { + BillingMode string `json:"billing_mode"` + ModelName string `json:"model_name"` + ExprString string `json:"expr_string"` + ExprHash string `json:"expr_hash"` + GroupRatio float64 `json:"group_ratio"` + EstimatedPromptTokens int `json:"estimated_prompt_tokens"` + EstimatedCompletionTokens int `json:"estimated_completion_tokens"` + EstimatedQuotaBeforeGroup float64 `json:"estimated_quota_before_group"` + EstimatedQuotaAfterGroup int `json:"estimated_quota_after_group"` + EstimatedTier string `json:"estimated_tier"` +} + +// TieredResult holds everything needed after running tiered settlement. +type TieredResult struct { + ActualQuotaBeforeGroup float64 `json:"actual_quota_before_group"` + ActualQuotaAfterGroup int `json:"actual_quota_after_group"` + MatchedTier string `json:"matched_tier"` + CrossedTier bool `json:"crossed_tier"` +} + +// ExprHashString returns the SHA-256 hex digest of an expression string. +func ExprHashString(expr string) string { + h := sha256.Sum256([]byte(expr)) + return fmt.Sprintf("%x", h) +} diff --git a/relay/audio_handler.go b/relay/audio_handler.go index 5c34b792..e16ec29b 100644 --- a/relay/audio_handler.go +++ b/relay/audio_handler.go @@ -46,7 +46,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type resp, err := adaptor.DoRequest(c, info, ioReader) if err != nil { - return types.NewError(err, types.ErrorCodeDoRequestFailed) + return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } statusCodeMappingStr := c.GetString("status_code_mapping") diff --git a/relay/chat_completions_via_responses.go b/relay/chat_completions_via_responses.go index 8f69b937..7a2eb9aa 100644 --- a/relay/chat_completions_via_responses.go +++ b/relay/chat_completions_via_responses.go @@ -2,6 +2,7 @@ package relay import ( "bytes" + "io" "net/http" "strings" @@ -124,8 +125,10 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad 
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } + var requestBody io.Reader = bytes.NewBuffer(jsonData) + var httpResp *http.Response - resp, err := adaptor.DoRequest(c, info, bytes.NewBuffer(jsonData)) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } diff --git a/relay/common/billing.go b/relay/common/billing.go index 78f5cb19..3971426b 100644 --- a/relay/common/billing.go +++ b/relay/common/billing.go @@ -18,4 +18,7 @@ type BillingSettler interface { // GetPreConsumedQuota 返回实际预扣的额度值(信任用户可能为 0)。 GetPreConsumedQuota() int + + // Reserve 将预扣额度补到目标值;若目标值不高于当前预扣额度则不做任何事。 + Reserve(targetQuota int) error } diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 8b0789c0..a130c6f9 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -10,6 +10,7 @@ import ( "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/pkg/billingexpr" relayconstant "github.com/QuantumNous/new-api/relay/constant" "github.com/QuantumNous/new-api/setting/model_setting" "github.com/QuantumNous/new-api/types" @@ -152,6 +153,11 @@ type RelayInfo struct { PriceData types.PriceData + // TieredBillingSnapshot is a frozen snapshot of tiered billing rules + // captured at pre-consume time. Non-nil only when billing mode is "tiered_expr". + TieredBillingSnapshot *billingexpr.BillingSnapshot + BillingRequestInput *billingexpr.RequestInput + Request dto.Request // RequestConversionChain records request format conversions in order, e.g. 
diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go index f60a485b..392677af 100644 --- a/relay/compatible_handler.go +++ b/relay/compatible_handler.go @@ -13,6 +13,7 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/logger" "github.com/QuantumNous/new-api/model" + "github.com/QuantumNous/new-api/pkg/billingexpr" relaycommon "github.com/QuantumNous/new-api/relay/common" relayconstant "github.com/QuantumNous/new-api/relay/constant" "github.com/QuantumNous/new-api/relay/helper" @@ -236,6 +237,18 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage service.ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat()) } + // Tiered billing early return + if ok, tieredQuota, tieredResult := service.TryTieredSettle(relayInfo, billingexpr.TokenParams{ + P: float64(usage.PromptTokens), + C: float64(usage.CompletionTokens), + CR: float64(usage.PromptTokensDetails.CachedTokens), + CC: float64(usage.PromptTokensDetails.CachedCreationTokens - usage.ClaudeCacheCreation1hTokens), + CC1h: float64(usage.ClaudeCacheCreation1hTokens), + }); ok { + postConsumeQuotaTiered(ctx, relayInfo, usage, tieredQuota, tieredResult, extraContent...) 
+ return + } + adminRejectReason := common.GetContextKeyString(ctx, constant.ContextKeyAdminRejectReason) useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() @@ -506,3 +519,51 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage Other: other, }) } + +func postConsumeQuotaTiered(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, quota int, tieredResult *service.TieredResultWrapper, extraContent ...string) { + _ = tieredResult // will be used for log enrichment + + useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() + modelName := relayInfo.OriginModelName + tokenName := ctx.GetString("token_name") + groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio + + totalTokens := usage.PromptTokens + usage.CompletionTokens + + if totalTokens == 0 { + quota = 0 + extraContent = append(extraContent, "上游没有返回计费信息,无法扣费(可能是上游超时)") + logger.LogError(ctx, fmt.Sprintf("tiered billing: total tokens is 0, userId %d, channelId %d, tokenId %d, model %s, pre-consumed %d", + relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota)) + } else { + if groupRatio != 0 && quota == 0 { + quota = 1 + } + model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) + model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) + } + + if err := service.SettleBilling(ctx, relayInfo, quota); err != nil { + logger.LogError(ctx, "error settling tiered billing: "+err.Error()) + } + + logModel := modelName + logContent := strings.Join(extraContent, ", ") + + other := service.GenerateTieredOtherInfo(ctx, relayInfo, tieredResult) + + model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ + ChannelId: relayInfo.ChannelId, + PromptTokens: usage.PromptTokens, + CompletionTokens: usage.CompletionTokens, + ModelName: logModel, + TokenName: tokenName, + Quota: quota, + Content: logContent, + TokenId: relayInfo.TokenId, + UseTimeSeconds: int(useTimeSeconds), + 
IsStream: relayInfo.IsStream, + Group: relayInfo.UsingGroup, + Other: other, + }) +} diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go index d8ca4223..7a73aa6e 100644 --- a/relay/embedding_handler.go +++ b/relay/embedding_handler.go @@ -3,6 +3,7 @@ package relay import ( "bytes" "fmt" + "io" "net/http" "github.com/QuantumNous/new-api/common" @@ -58,7 +59,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * } logger.LogDebug(c, fmt.Sprintf("converted embedding request body: %s", string(jsonData))) - requestBody := bytes.NewBuffer(jsonData) + var requestBody io.Reader = bytes.NewBuffer(jsonData) statusCodeMappingStr := c.GetString("status_code_mapping") resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { diff --git a/relay/helper/billing_expr_request.go b/relay/helper/billing_expr_request.go new file mode 100644 index 00000000..404a348f --- /dev/null +++ b/relay/helper/billing_expr_request.go @@ -0,0 +1,54 @@ +package helper + +import ( + "strings" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/pkg/billingexpr" + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/gin-gonic/gin" +) + +func ResolveIncomingBillingExprRequestInput(c *gin.Context, info *relaycommon.RelayInfo) (billingexpr.RequestInput, error) { + input := billingexpr.RequestInput{} + if info != nil { + input.Headers = cloneStringMap(info.RequestHeaders) + } + + bodyBytes, err := readIncomingBillingExprBody(c) + if err != nil { + return billingexpr.RequestInput{}, err + } + input.Body = bodyBytes + return input, nil +} + +func readIncomingBillingExprBody(c *gin.Context) ([]byte, error) { + if c == nil || c.Request == nil || !isJSONContentType(c.Request.Header.Get("Content-Type")) { + return nil, nil + } + storage, err := common.GetBodyStorage(c) + if err != nil { + return nil, err + } + return storage.Bytes() +} + +func isJSONContentType(contentType string) bool { + contentType = 
strings.ToLower(strings.TrimSpace(contentType)) + return strings.HasPrefix(contentType, "application/json") +} + +func cloneStringMap(src map[string]string) map[string]string { + if len(src) == 0 { + return map[string]string{} + } + dst := make(map[string]string, len(src)) + for key, value := range src { + if strings.TrimSpace(key) == "" { + continue + } + dst[key] = value + } + return dst +} diff --git a/relay/helper/billing_expr_request_test.go b/relay/helper/billing_expr_request_test.go new file mode 100644 index 00000000..c07aaa29 --- /dev/null +++ b/relay/helper/billing_expr_request_test.go @@ -0,0 +1,35 @@ +package helper + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/QuantumNous/new-api/common" + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestResolveIncomingBillingExprRequestInput(t *testing.T) { + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + ctx.Request.Header.Set("Content-Type", "application/json") + + body := []byte(`{"service_tier":"fast"}`) + ctx.Request.Body = io.NopCloser(bytes.NewReader(body)) + ctx.Set(common.KeyRequestBody, body) + + info := &relaycommon.RelayInfo{ + RequestHeaders: map[string]string{"Content-Type": "application/json"}, + } + + input, err := ResolveIncomingBillingExprRequestInput(ctx, info) + require.NoError(t, err) + require.Equal(t, body, input.Body) + require.Equal(t, "application/json", input.Headers["Content-Type"]) +} diff --git a/relay/helper/price.go b/relay/helper/price.go index f109040d..798f4e8d 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -5,7 +5,9 @@ import ( "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/logger" + "github.com/QuantumNous/new-api/pkg/billingexpr" relaycommon 
"github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/setting/billing_setting" "github.com/QuantumNous/new-api/setting/operation_setting" "github.com/QuantumNous/new-api/setting/ratio_setting" "github.com/QuantumNous/new-api/types" @@ -50,6 +52,11 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens groupRatioInfo := HandleGroupRatio(c, info) + // Check if this model uses tiered_expr billing + if billing_setting.GetBillingMode(info.OriginModelName) == billing_setting.BillingModeTieredExpr { + return modelPriceHelperTiered(c, info, promptTokens, meta, groupRatioInfo) + } + var preConsumedQuota int var modelRatio float64 var completionRatio float64 @@ -195,5 +202,73 @@ func ContainPriceOrRatio(modelName string) bool { if ok { return true } + if billing_setting.GetBillingMode(modelName) == billing_setting.BillingModeTieredExpr { + _, ok = billing_setting.GetBillingExpr(modelName) + return ok + } return false } + +func modelPriceHelperTiered(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, meta *types.TokenCountMeta, groupRatioInfo types.GroupRatioInfo) (types.PriceData, error) { + exprStr, ok := billing_setting.GetBillingExpr(info.OriginModelName) + if !ok { + return types.PriceData{}, fmt.Errorf("model %s is configured as tiered_expr but has no billing expression", info.OriginModelName) + } + + estimatedCompletionTokens := 0 + if meta.MaxTokens != 0 { + estimatedCompletionTokens = meta.MaxTokens + } + + requestInput, err := ResolveIncomingBillingExprRequestInput(c, info) + if err != nil { + return types.PriceData{}, err + } + + rawQuota, trace, err := billingexpr.RunExprWithRequest(exprStr, billingexpr.TokenParams{ + P: float64(promptTokens), + C: float64(estimatedCompletionTokens), + }, requestInput) + if err != nil { + return types.PriceData{}, fmt.Errorf("model %s tiered expr run failed: %w", info.OriginModelName, err) + } + + preConsumedQuota := billingexpr.QuotaRound(rawQuota * 
groupRatioInfo.GroupRatio) + + freeModel := false + if !operation_setting.GetQuotaSetting().EnableFreeModelPreConsume { + if groupRatioInfo.GroupRatio == 0 || rawQuota == 0 { + preConsumedQuota = 0 + freeModel = true + } + } + + exprHash := billingexpr.ExprHashString(exprStr) + snapshot := &billingexpr.BillingSnapshot{ + BillingMode: billing_setting.BillingModeTieredExpr, + ModelName: info.OriginModelName, + ExprString: exprStr, + ExprHash: exprHash, + GroupRatio: groupRatioInfo.GroupRatio, + EstimatedPromptTokens: promptTokens, + EstimatedCompletionTokens: estimatedCompletionTokens, + EstimatedQuotaBeforeGroup: rawQuota, + EstimatedQuotaAfterGroup: preConsumedQuota, + EstimatedTier: trace.MatchedTier, + } + info.TieredBillingSnapshot = snapshot + info.BillingRequestInput = &requestInput + + priceData := types.PriceData{ + FreeModel: freeModel, + GroupRatioInfo: groupRatioInfo, + QuotaToPreConsume: preConsumedQuota, + } + + if common.DebugEnabled { + println(fmt.Sprintf("model_price_helper_tiered result: model=%s preConsume=%d rawQuota=%.2f groupRatio=%.2f tier=%s", info.OriginModelName, preConsumedQuota, rawQuota, groupRatioInfo.GroupRatio, trace.MatchedTier)) + } + + info.PriceData = priceData + return priceData, nil +} diff --git a/service/billing_session.go b/service/billing_session.go index f24b68e5..3e3ef1f4 100644 --- a/service/billing_session.go +++ b/service/billing_session.go @@ -27,6 +27,8 @@ type BillingSession struct { funding FundingSource preConsumedQuota int // 实际预扣额度(信任用户可能为 0) tokenConsumed int // 令牌额度实际扣减量 + extraReserved int // 发送前补充预扣的额度(订阅退款时需要单独回滚) + trusted bool // 是否命中信任额度旁路 fundingSettled bool // funding.Settle 已成功,资金来源已提交 settled bool // Settle 全部完成(资金 + 令牌) refunded bool // Refund 已调用 @@ -97,6 +99,8 @@ func (s *BillingSession) Refund(c *gin.Context) { tokenKey := s.relayInfo.TokenKey isPlayground := s.relayInfo.IsPlayground tokenConsumed := s.tokenConsumed + extraReserved := s.extraReserved + subscriptionId := s.relayInfo.SubscriptionId 
funding := s.funding gopool.Go(func() { @@ -104,6 +108,11 @@ func (s *BillingSession) Refund(c *gin.Context) { if err := funding.Refund(); err != nil { common.SysLog("error refunding billing source: " + err.Error()) } + if extraReserved > 0 && funding.Source() == BillingSourceSubscription && subscriptionId > 0 { + if err := model.PostConsumeUserSubscriptionDelta(subscriptionId, -int64(extraReserved)); err != nil { + common.SysLog("error refunding subscription extra reserved quota: " + err.Error()) + } + } // 2) 退还令牌额度 if tokenConsumed > 0 && !isPlayground { if err := model.IncreaseTokenQuota(tokenId, tokenKey, tokenConsumed); err != nil { @@ -140,6 +149,34 @@ func (s *BillingSession) GetPreConsumedQuota() int { return s.preConsumedQuota } +func (s *BillingSession) Reserve(targetQuota int) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.settled || s.refunded || s.trusted || targetQuota <= s.preConsumedQuota { + return nil + } + + delta := targetQuota - s.preConsumedQuota + if delta <= 0 { + return nil + } + + if err := s.reserveFunding(delta); err != nil { + return err + } + if err := s.reserveToken(delta); err != nil { + s.rollbackFundingReserve(delta) + return err + } + + s.preConsumedQuota += delta + s.tokenConsumed += delta + s.extraReserved += delta + s.syncRelayInfo() + return nil +} + // --------------------------------------------------------------------------- // PreConsume — 统一预扣费入口(含信任额度旁路) // --------------------------------------------------------------------------- @@ -151,6 +188,7 @@ func (s *BillingSession) preConsume(c *gin.Context, quota int) *types.NewAPIErro // ---- 信任额度旁路 ---- if s.shouldTrust(c) { + s.trusted = true effectiveQuota = 0 logger.LogInfo(c, fmt.Sprintf("用户 %d 额度充足, 信任且不需要预扣费 (funding=%s)", s.relayInfo.UserId, s.funding.Source())) } else if effectiveQuota > 0 { @@ -191,6 +229,55 @@ func (s *BillingSession) preConsume(c *gin.Context, quota int) *types.NewAPIErro return nil } +func (s *BillingSession) reserveFunding(delta int) 
error { + switch funding := s.funding.(type) { + case *WalletFunding: + if err := model.DecreaseUserQuota(funding.userId, delta); err != nil { + return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry()) + } + funding.consumed += delta + return nil + case *SubscriptionFunding: + if err := model.PostConsumeUserSubscriptionDelta(funding.subscriptionId, int64(delta)); err != nil { + return types.NewErrorWithStatusCode( + fmt.Errorf("订阅额度不足或未配置订阅: %s", err.Error()), + types.ErrorCodeInsufficientUserQuota, + http.StatusForbidden, + types.ErrOptionWithSkipRetry(), + types.ErrOptionWithNoRecordErrorLog(), + ) + } + return nil + default: + return types.NewError(fmt.Errorf("unsupported funding source: %s", s.funding.Source()), types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry()) + } +} + +func (s *BillingSession) rollbackFundingReserve(delta int) { + switch funding := s.funding.(type) { + case *WalletFunding: + if err := model.IncreaseUserQuota(funding.userId, delta, false); err != nil { + common.SysLog("error rolling back wallet funding reserve: " + err.Error()) + } else { + funding.consumed -= delta + } + case *SubscriptionFunding: + if err := model.PostConsumeUserSubscriptionDelta(funding.subscriptionId, -int64(delta)); err != nil { + common.SysLog("error rolling back subscription funding reserve: " + err.Error()) + } + } +} + +func (s *BillingSession) reserveToken(delta int) error { + if delta <= 0 || s.relayInfo.IsPlayground { + return nil + } + if err := PreConsumeTokenQuota(s.relayInfo, delta); err != nil { + return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + } + return nil +} + // shouldTrust 统一信任额度检查,适用于钱包和订阅。 func (s *BillingSession) shouldTrust(c *gin.Context) bool { // 异步任务(ForcePreConsume=true)必须预扣全额,不允许信任旁路 @@ -235,10 +322,10 @@ func (s *BillingSession) syncRelayInfo() { if sub, ok := 
s.funding.(*SubscriptionFunding); ok { info.SubscriptionId = sub.subscriptionId - info.SubscriptionPreConsumed = sub.preConsumed + info.SubscriptionPreConsumed = sub.preConsumed + int64(s.extraReserved) info.SubscriptionPostDelta = 0 info.SubscriptionAmountTotal = sub.AmountTotal - info.SubscriptionAmountUsedAfterPreConsume = sub.AmountUsedAfter + info.SubscriptionAmountUsedAfterPreConsume = sub.AmountUsedAfter + int64(s.extraReserved) info.SubscriptionPlanId = sub.PlanId info.SubscriptionPlanTitle = sub.PlanTitle } else { diff --git a/service/log_info_generate.go b/service/log_info_generate.go index 1c440911..0737a911 100644 --- a/service/log_info_generate.go +++ b/service/log_info_generate.go @@ -6,6 +6,7 @@ import ( "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/pkg/billingexpr" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/types" @@ -214,3 +215,42 @@ func GenerateMjOtherInfo(relayInfo *relaycommon.RelayInfo, priceData types.Price appendRequestPath(nil, relayInfo, other) return other } + +func GenerateTieredOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, result *billingexpr.TieredResult) map[string]interface{} { + other := make(map[string]interface{}) + other["billing_mode"] = "tiered_expr" + + snap := relayInfo.TieredBillingSnapshot + if snap != nil { + other["group_ratio"] = snap.GroupRatio + other["expr_hash"] = snap.ExprHash + other["estimated_prompt_tokens"] = snap.EstimatedPromptTokens + other["estimated_completion_tokens"] = snap.EstimatedCompletionTokens + other["estimated_quota_before_group"] = snap.EstimatedQuotaBeforeGroup + other["estimated_quota_after_group"] = snap.EstimatedQuotaAfterGroup + other["estimated_tier"] = snap.EstimatedTier + } + + if result != nil { + other["actual_quota_before_group"] = result.ActualQuotaBeforeGroup + other["actual_quota_after_group"] = 
result.ActualQuotaAfterGroup + other["matched_tier"] = result.MatchedTier + other["crossed_tier"] = result.CrossedTier + } + + other["frt"] = float64(relayInfo.FirstResponseTime.UnixMilli() - relayInfo.StartTime.UnixMilli()) + if relayInfo.IsModelMapped { + other["is_model_mapped"] = true + other["upstream_model_name"] = relayInfo.UpstreamModelName + } + + adminInfo := make(map[string]interface{}) + adminInfo["use_channel"] = ctx.GetStringSlice("use_channel") + AppendChannelAffinityAdminInfo(ctx, adminInfo) + other["admin_info"] = adminInfo + + appendRequestPath(ctx, relayInfo, other) + appendRequestConversionChain(relayInfo, other) + appendBillingInfo(relayInfo, other) + return other +} diff --git a/service/quota.go b/service/quota.go index 7ee70edd..eafec351 100644 --- a/service/quota.go +++ b/service/quota.go @@ -13,6 +13,7 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/logger" "github.com/QuantumNous/new-api/model" + "github.com/QuantumNous/new-api/pkg/billingexpr" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/setting/ratio_setting" "github.com/QuantumNous/new-api/setting/system_setting" @@ -157,6 +158,15 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string, usage *dto.RealtimeUsage, extraContent string) { + // Tiered billing early return + if ok, tieredQuota, tieredResult := TryTieredSettle(relayInfo, billingexpr.TokenParams{ + P: float64(usage.InputTokens), + C: float64(usage.OutputTokens), + }); ok { + postConsumeQuotaTieredService(ctx, relayInfo, modelName, usage.InputTokens, usage.OutputTokens, usage.TotalTokens, tieredQuota, tieredResult, extraContent) + return + } + useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() textInputTokens := usage.InputTokenDetails.TextTokens textOutTokens := usage.OutputTokenDetails.TextTokens @@ -240,6 +250,18 @@ 
func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat()) } + // Tiered billing early return + if ok, tieredQuota, tieredResult := TryTieredSettle(relayInfo, billingexpr.TokenParams{ + P: float64(usage.PromptTokens), + C: float64(usage.CompletionTokens), + CR: float64(usage.PromptTokensDetails.CachedTokens), + CC: float64(usage.PromptTokensDetails.CachedCreationTokens - usage.ClaudeCacheCreation1hTokens), + CC1h: float64(usage.ClaudeCacheCreation1hTokens), + }); ok { + postConsumeQuotaTieredService(ctx, relayInfo, relayInfo.OriginModelName, usage.PromptTokens, usage.CompletionTokens, usage.PromptTokens+usage.CompletionTokens, tieredQuota, tieredResult, "") + return + } + useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() promptTokens := usage.PromptTokens completionTokens := usage.CompletionTokens @@ -360,6 +382,16 @@ func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData types.PriceData) func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) { + // Tiered billing early return + if ok, tieredQuota, tieredResult := TryTieredSettle(relayInfo, billingexpr.TokenParams{ + P: float64(usage.PromptTokens), + C: float64(usage.CompletionTokens), + CR: float64(usage.PromptTokensDetails.CachedTokens), + }); ok { + postConsumeQuotaTieredService(ctx, relayInfo, relayInfo.OriginModelName, usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens, tieredQuota, tieredResult, extraContent) + return + } + useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() textInputTokens := usage.PromptTokensDetails.TextTokens textOutTokens := usage.CompletionTokenDetails.TextTokens @@ -607,3 +639,50 @@ func checkAndSendSubscriptionQuotaNotify(relayInfo *relaycommon.RelayInfo) { } }) } + +func postConsumeQuotaTieredService(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, 
modelName string, + promptTokens, completionTokens, totalTokens, quota int, tieredResult *TieredResultWrapper, extraContent string) { + + useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() + tokenName := ctx.GetString("token_name") + groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio + + var logContent string + if totalTokens == 0 { + quota = 0 + logContent = "上游没有返回计费信息(可能是上游超时)" + logger.LogError(ctx, fmt.Sprintf("tiered billing: total tokens is 0, userId %d, channelId %d, tokenId %d, model %s, pre-consumed %d", + relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota)) + } else { + if groupRatio != 0 && quota == 0 { + quota = 1 + } + model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) + model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) + } + + if err := SettleBilling(ctx, relayInfo, quota); err != nil { + logger.LogError(ctx, "error settling tiered billing: "+err.Error()) + } + + if extraContent != "" { + logContent += extraContent + } + + other := GenerateTieredOtherInfo(ctx, relayInfo, tieredResult) + + model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ + ChannelId: relayInfo.ChannelId, + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + ModelName: modelName, + TokenName: tokenName, + Quota: quota, + Content: logContent, + TokenId: relayInfo.TokenId, + UseTimeSeconds: int(useTimeSeconds), + IsStream: relayInfo.IsStream, + Group: relayInfo.UsingGroup, + Other: other, + }) +} diff --git a/service/tiered_settle.go b/service/tiered_settle.go new file mode 100644 index 00000000..83f5b6f6 --- /dev/null +++ b/service/tiered_settle.go @@ -0,0 +1,36 @@ +package service + +import ( + "github.com/QuantumNous/new-api/pkg/billingexpr" + relaycommon "github.com/QuantumNous/new-api/relay/common" +) + +// TieredResultWrapper wraps billingexpr.TieredResult for use at the service layer. 
+type TieredResultWrapper = billingexpr.TieredResult + +// TryTieredSettle checks if the request uses tiered_expr billing and, if so, +// computes the actual quota using the frozen BillingSnapshot. Returns: +// - ok=true, quota, result when tiered billing applies +// - ok=false, 0, nil when it doesn't (caller should fall through to existing logic) +func TryTieredSettle(relayInfo *relaycommon.RelayInfo, params billingexpr.TokenParams) (ok bool, quota int, result *billingexpr.TieredResult) { + snap := relayInfo.TieredBillingSnapshot + if snap == nil || snap.BillingMode != "tiered_expr" { + return false, 0, nil + } + + requestInput := billingexpr.RequestInput{} + if relayInfo.BillingRequestInput != nil { + requestInput = *relayInfo.BillingRequestInput + } + + tr, err := billingexpr.ComputeTieredQuotaWithRequest(snap, params, requestInput) + if err != nil { + quota = relayInfo.FinalPreConsumedQuota + if quota <= 0 { + quota = snap.EstimatedQuotaAfterGroup + } + return true, quota, nil + } + + return true, tr.ActualQuotaAfterGroup, &tr +} diff --git a/service/tiered_settle_test.go b/service/tiered_settle_test.go new file mode 100644 index 00000000..4ec702ad --- /dev/null +++ b/service/tiered_settle_test.go @@ -0,0 +1,401 @@ +package service + +import ( + "testing" + + "github.com/QuantumNous/new-api/pkg/billingexpr" + relaycommon "github.com/QuantumNous/new-api/relay/common" +) + +// Claude Sonnet-style tiered expression: standard vs long-context +const sonnetTieredExpr = `p <= 200000 ? tier("standard", p * 1.5 + c * 7.5) : tier("long_context", p * 3 + c * 11.25)` + +// Simple flat expression +const flatExpr = `tier("default", p * 2 + c * 10)` + +// Expression with cache tokens +const cacheExpr = `tier("default", p * 2 + c * 10 + cr * 0.2 + cc * 2.5 + cc1h * 4)` + +// Expression with request probes +const probeExpr = `param("service_tier") == "fast" ? 
tier("fast", p * 4 + c * 20) : tier("normal", p * 2 + c * 10)` + +func makeSnapshot(expr string, groupRatio float64, estPrompt, estCompletion int) *billingexpr.BillingSnapshot { + return &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: expr, + ExprHash: billingexpr.ExprHashString(expr), + GroupRatio: groupRatio, + EstimatedPromptTokens: estPrompt, + EstimatedCompletionTokens: estCompletion, + } +} + +func makeRelayInfo(expr string, groupRatio float64, estPrompt, estCompletion int) *relaycommon.RelayInfo { + snap := makeSnapshot(expr, groupRatio, estPrompt, estCompletion) + cost, trace, _ := billingexpr.RunExpr(expr, billingexpr.TokenParams{P: float64(estPrompt), C: float64(estCompletion)}) + snap.EstimatedQuotaBeforeGroup = cost + snap.EstimatedQuotaAfterGroup = billingexpr.QuotaRound(cost * groupRatio) + snap.EstimatedTier = trace.MatchedTier + return &relaycommon.RelayInfo{ + TieredBillingSnapshot: snap, + FinalPreConsumedQuota: snap.EstimatedQuotaAfterGroup, + } +} + +// --------------------------------------------------------------------------- +// Existing tests (preserved) +// --------------------------------------------------------------------------- + +func TestTryTieredSettleUsesFrozenRequestInput(t *testing.T) { + exprStr := `param("service_tier") == "fast" ? 
tier("fast", p * 2) : tier("normal", p)` + relayInfo := &relaycommon.RelayInfo{ + TieredBillingSnapshot: &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: exprStr, + ExprHash: billingexpr.ExprHashString(exprStr), + GroupRatio: 1.0, + EstimatedPromptTokens: 100, + EstimatedCompletionTokens: 0, + EstimatedQuotaAfterGroup: 100, + }, + BillingRequestInput: &billingexpr.RequestInput{ + Body: []byte(`{"service_tier":"fast"}`), + }, + } + + ok, quota, result := TryTieredSettle(relayInfo, billingexpr.TokenParams{P: 100}) + if !ok { + t.Fatal("expected tiered settle to apply") + } + if quota != 200 { + t.Fatalf("quota = %d, want 200", quota) + } + if result == nil || result.MatchedTier != "fast" { + t.Fatalf("matched tier = %v, want fast", result) + } +} + +func TestTryTieredSettleFallsBackToFrozenPreConsumeOnExprError(t *testing.T) { + relayInfo := &relaycommon.RelayInfo{ + FinalPreConsumedQuota: 321, + TieredBillingSnapshot: &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: `invalid +-+ expr`, + ExprHash: billingexpr.ExprHashString(`invalid +-+ expr`), + GroupRatio: 1.0, + EstimatedQuotaAfterGroup: 123, + }, + } + + ok, quota, result := TryTieredSettle(relayInfo, billingexpr.TokenParams{P: 100}) + if !ok { + t.Fatal("expected tiered settle to apply") + } + if quota != 321 { + t.Fatalf("quota = %d, want 321", quota) + } + if result != nil { + t.Fatalf("result = %#v, want nil", result) + } +} + +// --------------------------------------------------------------------------- +// Pre-consume vs Post-consume consistency +// --------------------------------------------------------------------------- + +func TestTryTieredSettle_PreConsumeMatchesPostConsume(t *testing.T) { + info := makeRelayInfo(flatExpr, 1.0, 1000, 500) + params := billingexpr.TokenParams{P: 1000, C: 500} + + ok, quota, _ := TryTieredSettle(info, params) + if !ok { + t.Fatal("expected tiered settle") + } + // p*2 + c*10 = 2000 + 5000 = 7000 + if quota != 7000 { + 
t.Fatalf("quota = %d, want 7000", quota) + } + if quota != info.FinalPreConsumedQuota { + t.Fatalf("pre-consume %d != post-consume %d", info.FinalPreConsumedQuota, quota) + } +} + +func TestTryTieredSettle_PostConsumeOverPreConsume(t *testing.T) { + info := makeRelayInfo(flatExpr, 1.0, 1000, 500) + preConsumed := info.FinalPreConsumedQuota // 7000 + + // Actual usage is higher than estimated + params := billingexpr.TokenParams{P: 2000, C: 1000} + ok, quota, _ := TryTieredSettle(info, params) + if !ok { + t.Fatal("expected tiered settle") + } + // p*2 + c*10 = 4000 + 10000 = 14000 + if quota != 14000 { + t.Fatalf("quota = %d, want 14000", quota) + } + if quota <= preConsumed { + t.Fatalf("expected supplement: actual %d should > pre-consumed %d", quota, preConsumed) + } +} + +func TestTryTieredSettle_PostConsumeUnderPreConsume(t *testing.T) { + info := makeRelayInfo(flatExpr, 1.0, 1000, 500) + preConsumed := info.FinalPreConsumedQuota // 7000 + + // Actual usage is lower than estimated + params := billingexpr.TokenParams{P: 100, C: 50} + ok, quota, _ := TryTieredSettle(info, params) + if !ok { + t.Fatal("expected tiered settle") + } + // p*2 + c*10 = 200 + 500 = 700 + if quota != 700 { + t.Fatalf("quota = %d, want 700", quota) + } + if quota >= preConsumed { + t.Fatalf("expected refund: actual %d should < pre-consumed %d", quota, preConsumed) + } +} + +// --------------------------------------------------------------------------- +// Tiered boundary conditions +// --------------------------------------------------------------------------- + +func TestTryTieredSettle_ExactBoundary(t *testing.T) { + info := makeRelayInfo(sonnetTieredExpr, 1.0, 200000, 1000) + + // p == 200000 => standard tier (p <= 200000) + ok, quota, result := TryTieredSettle(info, billingexpr.TokenParams{P: 200000, C: 1000}) + if !ok { + t.Fatal("expected tiered settle") + } + // standard: p*1.5 + c*7.5 = 300000 + 7500 = 307500 + if quota != 307500 { + t.Fatalf("quota = %d, want 307500", quota) + } 
+ if result.MatchedTier != "standard" { + t.Fatalf("tier = %s, want standard", result.MatchedTier) + } +} + +func TestTryTieredSettle_BoundaryPlusOne(t *testing.T) { + info := makeRelayInfo(sonnetTieredExpr, 1.0, 200000, 1000) + + // p == 200001 => crosses to long_context tier + ok, quota, result := TryTieredSettle(info, billingexpr.TokenParams{P: 200001, C: 1000}) + if !ok { + t.Fatal("expected tiered settle") + } + // long_context: p*3 + c*11.25 = 600003 + 11250 = 611253 + if quota != 611253 { + t.Fatalf("quota = %d, want 611253", quota) + } + if result.MatchedTier != "long_context" { + t.Fatalf("tier = %s, want long_context", result.MatchedTier) + } + if !result.CrossedTier { + t.Fatal("expected CrossedTier = true") + } +} + +func TestTryTieredSettle_ZeroTokens(t *testing.T) { + info := makeRelayInfo(flatExpr, 1.0, 0, 0) + + ok, quota, result := TryTieredSettle(info, billingexpr.TokenParams{P: 0, C: 0}) + if !ok { + t.Fatal("expected tiered settle") + } + if quota != 0 { + t.Fatalf("quota = %d, want 0", quota) + } + if result == nil { + t.Fatal("result should not be nil") + } +} + +func TestTryTieredSettle_HugeTokens(t *testing.T) { + info := makeRelayInfo(flatExpr, 1.0, 10000000, 5000000) + + ok, quota, _ := TryTieredSettle(info, billingexpr.TokenParams{P: 10000000, C: 5000000}) + if !ok { + t.Fatal("expected tiered settle") + } + // p*2 + c*10 = 20000000 + 50000000 = 70000000 + if quota != 70000000 { + t.Fatalf("quota = %d, want 70000000", quota) + } +} + +func TestTryTieredSettle_CacheTokensAffectSettlement(t *testing.T) { + info := makeRelayInfo(cacheExpr, 1.0, 1000, 500) + + // Without cache tokens + ok1, quota1, _ := TryTieredSettle(info, billingexpr.TokenParams{P: 1000, C: 500}) + if !ok1 { + t.Fatal("expected tiered settle") + } + // p*2 + c*10 + cr*0.2 + cc*2.5 + cc1h*4 = 2000 + 5000 + 0 + 0 + 0 = 7000 + + // With cache tokens + ok2, quota2, _ := TryTieredSettle(info, billingexpr.TokenParams{P: 1000, C: 500, CR: 10000, CC: 5000, CC1h: 2000}) + if !ok2 { 
+ t.Fatal("expected tiered settle") + } + // 2000 + 5000 + 10000*0.2 + 5000*2.5 + 2000*4 = 2000 + 5000 + 2000 + 12500 + 8000 = 29500 + + if quota2 <= quota1 { + t.Fatalf("cache tokens should increase quota: without=%d, with=%d", quota1, quota2) + } + if quota1 != 7000 { + t.Fatalf("no-cache quota = %d, want 7000", quota1) + } + if quota2 != 29500 { + t.Fatalf("cache quota = %d, want 29500", quota2) + } +} + +// --------------------------------------------------------------------------- +// Request probe tests +// --------------------------------------------------------------------------- + +func TestTryTieredSettle_RequestProbeInfluencesBilling(t *testing.T) { + info := makeRelayInfo(probeExpr, 1.0, 1000, 500) + info.BillingRequestInput = &billingexpr.RequestInput{ + Body: []byte(`{"service_tier":"fast"}`), + } + + ok, quota, result := TryTieredSettle(info, billingexpr.TokenParams{P: 1000, C: 500}) + if !ok { + t.Fatal("expected tiered settle") + } + // fast: p*4 + c*20 = 4000 + 10000 = 14000 + if quota != 14000 { + t.Fatalf("quota = %d, want 14000", quota) + } + if result.MatchedTier != "fast" { + t.Fatalf("tier = %s, want fast", result.MatchedTier) + } +} + +func TestTryTieredSettle_NoRequestInput_FallsBackToDefault(t *testing.T) { + info := makeRelayInfo(probeExpr, 1.0, 1000, 500) + // No BillingRequestInput set — param("service_tier") returns nil, not "fast" + + ok, quota, result := TryTieredSettle(info, billingexpr.TokenParams{P: 1000, C: 500}) + if !ok { + t.Fatal("expected tiered settle") + } + // normal: p*2 + c*10 = 2000 + 5000 = 7000 + if quota != 7000 { + t.Fatalf("quota = %d, want 7000", quota) + } + if result.MatchedTier != "normal" { + t.Fatalf("tier = %s, want normal", result.MatchedTier) + } +} + +// --------------------------------------------------------------------------- +// Group ratio tests +// --------------------------------------------------------------------------- + +func TestTryTieredSettle_GroupRatioScaling(t *testing.T) { + info := 
makeRelayInfo(flatExpr, 1.5, 1000, 500) + + ok, quota, _ := TryTieredSettle(info, billingexpr.TokenParams{P: 1000, C: 500}) + if !ok { + t.Fatal("expected tiered settle") + } + // cost = 7000, after group = round(7000 * 1.5) = 10500 + if quota != 10500 { + t.Fatalf("quota = %d, want 10500", quota) + } +} + +func TestTryTieredSettle_GroupRatioZero(t *testing.T) { + info := makeRelayInfo(flatExpr, 0, 1000, 500) + + ok, quota, _ := TryTieredSettle(info, billingexpr.TokenParams{P: 1000, C: 500}) + if !ok { + t.Fatal("expected tiered settle") + } + if quota != 0 { + t.Fatalf("quota = %d, want 0 (group ratio = 0)", quota) + } +} + +// --------------------------------------------------------------------------- +// Ratio mode (negative tests) — TryTieredSettle must return false +// --------------------------------------------------------------------------- + +func TestTryTieredSettle_RatioMode_NilSnapshot(t *testing.T) { + info := &relaycommon.RelayInfo{ + TieredBillingSnapshot: nil, + } + + ok, _, _ := TryTieredSettle(info, billingexpr.TokenParams{P: 1000, C: 500}) + if ok { + t.Fatal("expected TryTieredSettle to return false when snapshot is nil") + } +} + +func TestTryTieredSettle_RatioMode_WrongBillingMode(t *testing.T) { + info := &relaycommon.RelayInfo{ + TieredBillingSnapshot: &billingexpr.BillingSnapshot{ + BillingMode: "ratio", + ExprString: flatExpr, + ExprHash: billingexpr.ExprHashString(flatExpr), + GroupRatio: 1.0, + }, + } + + ok, _, _ := TryTieredSettle(info, billingexpr.TokenParams{P: 1000, C: 500}) + if ok { + t.Fatal("expected TryTieredSettle to return false for ratio billing mode") + } +} + +func TestTryTieredSettle_RatioMode_EmptyBillingMode(t *testing.T) { + info := &relaycommon.RelayInfo{ + TieredBillingSnapshot: &billingexpr.BillingSnapshot{ + BillingMode: "", + ExprString: flatExpr, + ExprHash: billingexpr.ExprHashString(flatExpr), + GroupRatio: 1.0, + }, + } + + ok, _, _ := TryTieredSettle(info, billingexpr.TokenParams{P: 1000, C: 500}) + if ok { + 
t.Fatal("expected TryTieredSettle to return false for empty billing mode") + } +} + +// --------------------------------------------------------------------------- +// Fallback tests +// --------------------------------------------------------------------------- + +func TestTryTieredSettle_ErrorFallbackToEstimatedQuotaAfterGroup(t *testing.T) { + info := &relaycommon.RelayInfo{ + FinalPreConsumedQuota: 0, + TieredBillingSnapshot: &billingexpr.BillingSnapshot{ + BillingMode: "tiered_expr", + ExprString: `invalid expr!!!`, + ExprHash: billingexpr.ExprHashString(`invalid expr!!!`), + GroupRatio: 1.0, + EstimatedQuotaAfterGroup: 999, + }, + } + + ok, quota, result := TryTieredSettle(info, billingexpr.TokenParams{P: 100}) + if !ok { + t.Fatal("expected tiered settle to apply") + } + // FinalPreConsumedQuota is 0, should fall back to EstimatedQuotaAfterGroup + if quota != 999 { + t.Fatalf("quota = %d, want 999", quota) + } + if result != nil { + t.Fatal("result should be nil on error fallback") + } +} diff --git a/setting/billing_setting/tiered_billing.go b/setting/billing_setting/tiered_billing.go new file mode 100644 index 00000000..85d5e628 --- /dev/null +++ b/setting/billing_setting/tiered_billing.go @@ -0,0 +1,135 @@ +package billing_setting + +import ( + "fmt" + "sync" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/pkg/billingexpr" +) + +var ( + mu sync.RWMutex + + // model -> "ratio" | "tiered_expr" + billingModeMap = make(map[string]string) + + // model -> expr string (authored by frontend, stored directly) + billingExprMap = make(map[string]string) +) + +const ( + BillingModeRatio = "ratio" + BillingModeTieredExpr = "tiered_expr" +) + +// --------------------------------------------------------------------------- +// Read accessors (hot path, must be fast) +// --------------------------------------------------------------------------- + +func GetBillingMode(model string) string { + mu.RLock() + defer mu.RUnlock() + if mode, ok := 
billingModeMap[model]; ok { + return mode + } + return BillingModeRatio +} + +func GetBillingExpr(model string) (string, bool) { + mu.RLock() + defer mu.RUnlock() + expr, ok := billingExprMap[model] + return expr, ok +} + +func UpdateBillingModeByJSONString(jsonStr string) error { + var m map[string]string + if err := common.Unmarshal([]byte(jsonStr), &m); err != nil { + return fmt.Errorf("parse ModelBillingMode: %w", err) + } + for k, v := range m { + if v != BillingModeRatio && v != BillingModeTieredExpr { + return fmt.Errorf("invalid billing mode %q for model %q", v, k) + } + } + mu.Lock() + billingModeMap = m + mu.Unlock() + return nil +} + +func UpdateBillingExprByJSONString(jsonStr string) error { + var m map[string]string + if err := common.Unmarshal([]byte(jsonStr), &m); err != nil { + return fmt.Errorf("parse ModelBillingExpr: %w", err) + } + for model, exprStr := range m { + if _, err := billingexpr.CompileFromCache(exprStr); err != nil { + return fmt.Errorf("model %q: %w", model, err) + } + if err := smokeTestExpr(exprStr); err != nil { + return fmt.Errorf("model %q smoke test: %w", model, err) + } + } + mu.Lock() + billingExprMap = m + mu.Unlock() + billingexpr.InvalidateCache() + return nil +} + +// --------------------------------------------------------------------------- +// JSON serializers (for OptionMap / API response) +// --------------------------------------------------------------------------- + +func BillingMode2JSONString() string { + mu.RLock() + defer mu.RUnlock() + b, err := common.Marshal(billingModeMap) + if err != nil { + return "{}" + } + return string(b) +} + +func BillingExpr2JSONString() string { + mu.RLock() + defer mu.RUnlock() + b, err := common.Marshal(billingExprMap) + if err != nil { + return "{}" + } + return string(b) +} + +func smokeTestExpr(exprStr string) error { + vectors := []billingexpr.TokenParams{ + {P: 0, C: 0}, + {P: 1000, C: 1000}, + {P: 100000, C: 100000}, + {P: 1000000, C: 1000000}, + } + requests := 
[]billingexpr.RequestInput{ + {}, + { + Headers: map[string]string{ + "anthropic-beta": "fast-mode-2026-02-01", + }, + Body: []byte(`{"service_tier":"fast","stream_options":{"include_usage":true},"messages":[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21]}`), + }, + } + + for _, v := range vectors { + for _, request := range requests { + result, _, err := billingexpr.RunExprWithRequest(exprStr, v, request) + if err != nil { + return fmt.Errorf("vector {p=%g, c=%g}: run failed: %w", v.P, v.C, err) + } + if result < 0 { + return fmt.Errorf("vector {p=%g, c=%g}: result %f < 0", v.P, v.C, result) + } + } + } + return nil +} diff --git a/setting/model_setting/claude_test.go b/setting/model_setting/claude_test.go new file mode 100644 index 00000000..0a806a7a --- /dev/null +++ b/setting/model_setting/claude_test.go @@ -0,0 +1,60 @@ +package model_setting + +import ( + "net/http" + "testing" +) + +func TestClaudeSettingsWriteHeadersMergesConfiguredValuesIntoSingleHeader(t *testing.T) { + settings := &ClaudeSettings{ + HeadersSettings: map[string]map[string][]string{ + "claude-3-7-sonnet-20250219-thinking": { + "anthropic-beta": { + "token-efficient-tools-2025-02-19", + }, + }, + }, + } + + headers := http.Header{} + headers.Set("anthropic-beta", "output-128k-2025-02-19") + + settings.WriteHeaders("claude-3-7-sonnet-20250219-thinking", &headers) + + got := headers.Values("anthropic-beta") + if len(got) != 1 { + t.Fatalf("expected a single merged header value, got %v", got) + } + expected := "output-128k-2025-02-19,token-efficient-tools-2025-02-19" + if got[0] != expected { + t.Fatalf("expected merged header %q, got %q", expected, got[0]) + } +} + +func TestClaudeSettingsWriteHeadersDeduplicatesAcrossCommaSeparatedAndRepeatedValues(t *testing.T) { + settings := &ClaudeSettings{ + HeadersSettings: map[string]map[string][]string{ + "claude-3-7-sonnet-20250219-thinking": { + "anthropic-beta": { + "token-efficient-tools-2025-02-19", + "computer-use-2025-01-24", + }, + }, 
+ }, + } + + headers := http.Header{} + headers.Add("anthropic-beta", "output-128k-2025-02-19, token-efficient-tools-2025-02-19") + headers.Add("anthropic-beta", "token-efficient-tools-2025-02-19") + + settings.WriteHeaders("claude-3-7-sonnet-20250219-thinking", &headers) + + got := headers.Values("anthropic-beta") + if len(got) != 1 { + t.Fatalf("expected duplicate values to collapse into one header, got %v", got) + } + expected := "output-128k-2025-02-19,token-efficient-tools-2025-02-19,computer-use-2025-01-24" + if got[0] != expected { + t.Fatalf("expected deduplicated merged header %q, got %q", expected, got[0]) + } +} diff --git a/web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx b/web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx index 61104d22..6f5021b4 100644 --- a/web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx +++ b/web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx @@ -377,6 +377,43 @@ function renderCompactDetailSummary(summarySegments) { ); } +function buildTieredBillingSegments(other, t) { + const segments = [ + { text: `${t('阶梯计费')}`, tone: 'primary' }, + ]; + + if (other.matched_tier) { + segments.push({ + text: `${t('命中档位')}: ${other.matched_tier}`, + tone: 'secondary', + }); + } + + const groupRatio = other.group_ratio; + if (groupRatio !== undefined && groupRatio !== null) { + segments.push({ + text: `${t('分组')} ${formatRatio(groupRatio)}x`, + tone: 'secondary', + }); + } + + if (other.crossed_tier) { + segments.push({ + text: `${t('跨阶梯')}: ${t('是')}`, + tone: 'secondary', + }); + } + + if (other.actual_quota_after_group !== undefined) { + segments.push({ + text: `${t('实际额度')}: ${other.actual_quota_after_group}`, + tone: 'secondary', + }); + } + + return { segments }; +} + function getUsageLogDetailSummary(record, text, billingDisplayMode, t) { const other = getLogOther(record.other); @@ -414,6 +451,10 @@ function getUsageLogDetailSummary(record, text, billingDisplayMode, t) { }; } + if 
(other?.billing_mode === 'tiered_expr') { + return buildTieredBillingSegments(other, t); + } + return { segments: other?.claude ? renderModelPriceSimple( diff --git a/web/src/hooks/usage-logs/useUsageLogsData.jsx b/web/src/hooks/usage-logs/useUsageLogsData.jsx index ebe1b882..84019a22 100644 --- a/web/src/hooks/usage-logs/useUsageLogsData.jsx +++ b/web/src/hooks/usage-logs/useUsageLogsData.jsx @@ -559,6 +559,66 @@ export const useLogsData = () => { value: other.reasoning_effort, }); } + if (other?.billing_mode === 'tiered_expr') { + expandDataLocal.push({ + key: t('计费方式'), + value: t('阶梯计费'), + }); + if (other?.group_ratio !== undefined) { + const gr = other.group_ratio; + expandDataLocal.push({ + key: t('分组倍率'), + value: typeof gr === 'number' ? gr.toFixed(4) : String(gr ?? '-'), + }); + } + if (other?.rule_version !== undefined) { + expandDataLocal.push({ + key: t('规则版本'), + value: String(other.rule_version), + }); + } + if (other?.estimated_env) { + expandDataLocal.push({ + key: t('预估环境'), + value: `prompt=${other.estimated_env.prompt_tokens ?? 0}, completion=${other.estimated_env.completion_tokens ?? 0}`, + }); + } + if (other?.actual_env) { + expandDataLocal.push({ + key: t('实际环境'), + value: `prompt=${other.actual_env.prompt_tokens ?? 0}, completion=${other.actual_env.completion_tokens ?? 0}`, + }); + } + if (other?.estimated_quota_after_group !== undefined) { + expandDataLocal.push({ + key: t('预估额度'), + value: String(other.estimated_quota_after_group), + }); + } + if (other?.actual_quota_after_group !== undefined) { + expandDataLocal.push({ + key: t('实际额度'), + value: String(other.actual_quota_after_group), + }); + } + expandDataLocal.push({ + key: t('跨阶梯'), + value: other?.crossed_tier ? 
t('是') : t('否'), + }); + if (Array.isArray(other?.breakdown) && other.breakdown.length > 0) { + const breakdownText = other.breakdown.map((item, idx) => + `[${idx}] ${item.token_type} | tokens=${item.tokens_in_tier} | cost=${item.unit_cost} | flat=${item.flat_fee} | sub=${item.subtotal}` + ).join('\n'); + expandDataLocal.push({ + key: t('计费明细'), + value: ( +
+ {breakdownText} +
+ ), + }); + } + } } if (logs[i].type === 6) { if (other?.task_id) { diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index e7213cd3..583d525c 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -3256,6 +3256,20 @@ "补全价格": "Completion Price", "缓存读取价格": "Input Cache Read Price", "缓存创建价格": "Input Cache Creation Price", + "缓存创建价格-5分钟": "Cache Creation Price (5-min)", + "缓存创建价格-1小时": "Cache Creation Price (1-hour)", + "缓存创建价格(5分钟)": "Cache Creation Price (5-min)", + "缓存创建价格(1小时)": "Cache Creation Price (1-hour)", + "分时缓存 (Claude)": "Timed Cache (Claude)", + "通用缓存": "Generic Cache", + "缓存读取": "Cache Read", + "缓存创建": "Cache Creation", + "缓存创建-5分钟": "Cache Creation (5-min)", + "缓存创建-1小时": "Cache Creation (1-hour)", + "缓存读取 Token (cr)": "Cache Read Tokens (cr)", + "缓存创建 Token (cc)": "Cache Creation Tokens (cc)", + "缓存创建-5分钟 (cc5)": "Cache Creation-5min (cc5)", + "缓存创建-1小时 (cc1h)": "Cache Creation-1hour (cc1h)", "图片输入价格": "Image Input Price", "音频输入价格": "Audio Input Price", "音频输入价格:{{symbol}}{{price}} / 1M tokens": "Audio input price: {{symbol}}{{price}} / 1M tokens", @@ -3309,6 +3323,104 @@ "输入价格:{{symbol}}{{price}} / 1M tokens": "Input Price: {{symbol}}{{price}} / 1M tokens", "输出价格 {{symbol}}{{price}} / 1M tokens": "Output Price {{symbol}}{{price}} / 1M tokens", "输出价格:{{symbol}}{{price}} / 1M tokens": "Output Price: {{symbol}}{{price}} / 1M tokens", - "输出价格:{{symbol}}{{total}} / 1M tokens": "Output Price: {{symbol}}{{total}} / 1M tokens" + "输出价格:{{symbol}}{{total}} / 1M tokens": "Output Price: {{symbol}}{{total}} / 1M tokens", + "阶梯计费": "Tiered Billing", + "输入 Tokens 阶梯": "Input Token Tiers", + "输出 Tokens 阶梯": "Output Token Tiers", + "固定阶梯": "Fixed Tier", + "累进阶梯": "Graduated Tier", + "上限": "Up To", + "单价": "Unit Cost", + "固定费": "Flat Fee", + "Expr 预览": "Expression Preview", + "Token 估算器": "Token Estimator", + "预计费用": "Estimated Cost", + "原始额度": "Raw Quota", + "添加阶梯": "Add Tier", + "无限": "Unlimited", + "输入 Token 定价": 
"Input Token Pricing", + "输出 Token 定价": "Output Token Pricing", + "统一定价": "Flat Rate", + "阶梯累进": "Graduated", + "根据总用量落在哪个档位,所有 Token 都按该档价格计费": "All tokens are charged at the rate of the tier your total usage falls into", + "用量分段计价,每一段各自按对应档位价格计费(类似电费阶梯)": "Usage is charged in segments — each segment at its own tier rate (like utility billing)", + "Token 用量范围": "Token Usage Range", + "所有 Token": "All Tokens", + "前 {{count}} 个": "First {{count}}", + "超过 {{count}} 个": "Over {{count}}", + "第 {{n}} 档": "Tier {{n}}", + "最高档": "Highest Tier", + "此档上限(Token 数)": "Tier Limit (Token Count)", + "每百万 Token 价格": "Price per 1M Tokens", + "进入此档额外收费": "Tier Entry Fee", + "可选,用量达到此档时加收的固定费用": "Optional fixed fee charged when usage reaches this tier", + "添加更多档位": "Add More Tiers", + "输入 Token 数": "Input Tokens", + "输出 Token 数": "Output Tokens", + "输入 Token 数量,查看按当前阶梯配置的预计费用。": "Enter token counts to see the estimated cost with the current tier configuration.", + "开发者": "Developer", + "阶梯计费详情": "Tiered Billing Details", + "预估环境": "Estimated Env", + "实际环境": "Actual Env", + "预估额度": "Estimated Quota", + "实际额度": "Actual Quota", + "跨阶梯": "Crossed Tier", + "是": "Yes", + "否": "No", + "计费明细": "Billing Breakdown", + "阶梯序号": "Tier #", + "Token 类型": "Token Type", + "阶梯内 Token 数": "Tokens in Tier", + "小计": "Subtotal", + "输入": "Input", + "输出": "Output", + "阶梯配置摘要": "Tier Config Summary", + "输入阶梯": "Input Tiers", + "档位名称": "Tier Name", + "用量范围": "Usage Range", + "输入 Token": "Input Token", + "输出 Token": "Output Token", + "阶梯判断依据": "Tier Criterion", + "根据哪个维度的 Token 数量决定落在哪一档": "Determines which tier to apply based on this dimension's token count", + "输入 Token 数 (p)": "Input Tokens (p)", + "输出 Token 数 (c)": "Output Tokens (c)", + "变量": "Variables", + "函数": "Functions", + "输入计费表达式...": "Enter billing expression...", + "表达式编辑": "Expression Editor", + "表达式错误": "Expression Error", + "命中档位": "Matched Tier", + "档": "tier(s)", + "输入 Token 数量,查看按当前配置的预计费用。": "Enter token counts to see the estimated 
cost.", + "输入 Token 数量,查看按当前配置的预计费用(不含分组倍率)。": "Enter token counts to see the estimated cost (before group ratio).", + "条件": "Condition", + "添加条件": "Add Condition", + "无条件(兜底档)": "No condition (fallback)", + "兜底档": "Fallback", + "预设模板": "Presets", + "每个档位可设置 0~2 个条件(对 p 和 c),最后一档为兜底档无需条件。": "Each tier can have 0-2 conditions (on p and c). The last tier is the fallback and needs no condition.", + "输出阶梯": "Output Tiers", + "阶": "tiers", + "规则版本": "Rule Version", + "时间条件": "Time condition", + "小时": "Hour", + "分钟": "Minute", + "星期": "Weekday", + "月份": "Month", + "日期": "Day", + "时区": "Timezone", + "跨夜范围": "Cross-midnight range", + "添加时间规则": "Add time rule", + "起": "From", + "止": "To", + "值": "Value", + "添加条件组": "Add condition group", + "添加条件": "Add condition", + "添加时间条件": "Add time condition", + "同时满足": "all must match", + "新年促销": "New Year promo", + "第 {{n}} 组": "Group {{n}}", + "0=周日 1=周一 2=周二 3=周三 4=周四 5=周五 6=周六": "0=Sun 1=Mon 2=Tue 3=Wed 4=Thu 5=Fri 6=Sat", + "1=一月 ... 12=十二月": "1=Jan ... 
12=Dec" } } diff --git a/web/src/i18n/locales/zh-CN.json b/web/src/i18n/locales/zh-CN.json index 02681108..ccbdee91 100644 --- a/web/src/i18n/locales/zh-CN.json +++ b/web/src/i18n/locales/zh-CN.json @@ -2858,6 +2858,20 @@ "补全价格": "补全价格", "缓存读取价格": "缓存读取价格", "缓存创建价格": "缓存创建价格", + "缓存创建价格-5分钟": "缓存创建价格-5分钟", + "缓存创建价格-1小时": "缓存创建价格-1小时", + "缓存创建价格(5分钟)": "缓存创建价格(5分钟)", + "缓存创建价格(1小时)": "缓存创建价格(1小时)", + "分时缓存 (Claude)": "分时缓存 (Claude)", + "通用缓存": "通用缓存", + "缓存读取": "缓存读取", + "缓存创建": "缓存创建", + "缓存创建-5分钟": "缓存创建-5分钟", + "缓存创建-1小时": "缓存创建-1小时", + "缓存读取 Token (cr)": "缓存读取 Token (cr)", + "缓存创建 Token (cc)": "缓存创建 Token (cc)", + "缓存创建-5分钟 (cc5)": "缓存创建-5分钟 (cc5)", + "缓存创建-1小时 (cc1h)": "缓存创建-1小时 (cc1h)", "图片输入价格": "图片输入价格", "音频输入价格": "音频输入价格", "音频补全价格": "音频补全价格", @@ -2938,6 +2952,102 @@ "输入价格:{{symbol}}{{price}} / 1M tokens": "输入价格:{{symbol}}{{price}} / 1M tokens", "输出价格 {{symbol}}{{price}} / 1M tokens": "输出价格 {{symbol}}{{price}} / 1M tokens", "输出价格:{{symbol}}{{price}} / 1M tokens": "输出价格:{{symbol}}{{price}} / 1M tokens", - "输出价格:{{symbol}}{{total}} / 1M tokens": "输出价格:{{symbol}}{{total}} / 1M tokens" + "输出价格:{{symbol}}{{total}} / 1M tokens": "输出价格:{{symbol}}{{total}} / 1M tokens", + "阶梯计费": "阶梯计费", + "输入 Tokens 阶梯": "输入 Tokens 阶梯", + "输出 Tokens 阶梯": "输出 Tokens 阶梯", + "固定阶梯": "固定阶梯", + "累进阶梯": "累进阶梯", + "上限": "上限", + "单价": "单价", + "固定费": "固定费", + "Expr 预览": "Expr 预览", + "Token 估算器": "Token 估算器", + "预计费用": "预计费用", + "添加阶梯": "添加阶梯", + "无限": "无限", + "输入 Token 定价": "输入 Token 定价", + "输出 Token 定价": "输出 Token 定价", + "统一定价": "统一定价", + "阶梯累进": "阶梯累进", + "根据总用量落在哪个档位,所有 Token 都按该档价格计费": "根据总用量落在哪个档位,所有 Token 都按该档价格计费", + "用量分段计价,每一段各自按对应档位价格计费(类似电费阶梯)": "用量分段计价,每一段各自按对应档位价格计费(类似电费阶梯)", + "Token 用量范围": "Token 用量范围", + "所有 Token": "所有 Token", + "前 {{count}} 个": "前 {{count}} 个", + "超过 {{count}} 个": "超过 {{count}} 个", + "第 {{n}} 档": "第 {{n}} 档", + "最高档": "最高档", + "此档上限(Token 数)": "此档上限(Token 数)", + "每百万 Token 价格": "每百万 Token 价格", + "进入此档额外收费": "进入此档额外收费", + "可选,用量达到此档时加收的固定费用": 
"可选,用量达到此档时加收的固定费用", + "添加更多档位": "添加更多档位", + "输入 Token 数": "输入 Token 数", + "输出 Token 数": "输出 Token 数", + "输入 Token 数量,查看按当前阶梯配置的预计费用。": "输入 Token 数量,查看按当前阶梯配置的预计费用。", + "开发者": "开发者", + "阶梯计费详情": "阶梯计费详情", + "预估环境": "预估环境", + "实际环境": "实际环境", + "预估额度": "预估额度", + "实际额度": "实际额度", + "跨阶梯": "跨阶梯", + "是": "是", + "否": "否", + "计费明细": "计费明细", + "阶梯序号": "阶梯序号", + "Token 类型": "Token 类型", + "阶梯内 Token 数": "阶梯内 Token 数", + "小计": "小计", + "输入": "输入", + "档位标签": "档位标签", + "用量范围": "用量范围", + "输入 Token": "输入 Token", + "输出 Token": "输出 Token", + "阶梯判断依据": "阶梯判断依据", + "根据哪个维度的 Token 数量决定落在哪一档": "根据哪个维度的 Token 数量决定落在哪一档", + "输入 Token 数 (p)": "输入 Token 数 (p)", + "输出 Token 数 (c)": "输出 Token 数 (c)", + "变量": "变量", + "函数": "函数", + "输入计费表达式...": "输入计费表达式...", + "表达式编辑": "表达式编辑", + "表达式错误": "表达式错误", + "命中档位": "命中档位", + "档": "档", + "输入 Token 数量,查看按当前配置的预计费用。": "输入 Token 数量,查看按当前配置的预计费用。", + "条件": "条件", + "添加条件": "添加条件", + "无条件(兜底档)": "无条件(兜底档)", + "兜底档": "兜底档", + "预设模板": "预设模板", + "每个档位可设置 0~2 个条件(对 p 和 c),最后一档为兜底档无需条件。": "每个档位可设置 0~2 个条件(对 p 和 c),最后一档为兜底档无需条件。", + "输出": "输出", + "阶梯配置摘要": "阶梯配置摘要", + "输入阶梯": "输入阶梯", + "输出阶梯": "输出阶梯", + "阶": "阶", + "规则版本": "规则版本", + "时间条件": "时间条件", + "小时": "小时", + "分钟": "分钟", + "星期": "星期", + "月份": "月份", + "日期": "日期", + "时区": "时区", + "跨夜范围": "跨夜范围", + "添加时间规则": "添加时间规则", + "起": "起", + "止": "止", + "值": "值", + "添加条件组": "添加条件组", + "添加条件": "添加条件", + "添加时间条件": "添加时间条件", + "同时满足": "同时满足", + "新年促销": "新年促销", + "第 {{n}} 组": "第 {{n}} 组", + "0=周日 1=周一 2=周二 3=周三 4=周四 5=周五 6=周六": "0=周日 1=周一 2=周二 3=周三 4=周四 5=周五 6=周六", + "1=一月 ... 12=十二月": "1=一月 ... 12=十二月" } } diff --git a/web/src/pages/Setting/Ratio/components/ModelPricingEditor.jsx b/web/src/pages/Setting/Ratio/components/ModelPricingEditor.jsx index 5028a3ff..0a6bb45b 100644 --- a/web/src/pages/Setting/Ratio/components/ModelPricingEditor.jsx +++ b/web/src/pages/Setting/Ratio/components/ModelPricingEditor.jsx @@ -17,7 +17,7 @@ along with this program. If not, see . 
For commercial licensing, please contact support@quantumnous.com */ -import React, { useMemo, useState } from 'react'; +import React, { useCallback, useMemo, useState } from 'react'; import { Banner, Button, @@ -49,6 +49,7 @@ import { useModelPricingEditorState, } from '../hooks/useModelPricingEditorState'; import { useIsMobile } from '../../../../hooks/common/useIsMobile'; +import TieredPricingEditor from './TieredPricingEditor'; const { Text } = Typography; const EMPTY_CANDIDATE_MODEL_NAMES = []; @@ -123,6 +124,8 @@ export default function ModelPricingEditor({ handleOptionalFieldToggle, handleNumericFieldChange, handleBillingModeChange, + handleBillingExprChange, + handleRequestRuleExprChange, handleSubmit, addModel, deleteModel, @@ -135,6 +138,15 @@ export default function ModelPricingEditor({ filterMode, }); + const getExprModeLabel = useCallback((model) => { + if (model?.billingMode !== 'tiered_expr') { + return ''; + } + return (model.billingExpr || '').includes('tier(') + ? t('阶梯计费') + : t('表达式计费'); + }, [t]); + const columns = useMemo( () => [ { @@ -175,10 +187,20 @@ export default function ModelPricingEditor({ dataIndex: 'billingMode', key: 'billingMode', render: (_, record) => ( - + {record.billingMode === 'per-request' ? t('按次计费') - : t('按量计费')} + : record.billingMode === 'tiered_expr' + ? getExprModeLabel(record) + : t('按量计费')} ), }, @@ -208,6 +230,7 @@ export default function ModelPricingEditor({ [ allowDeleteModel, deleteModel, + getExprModeLabel, selectedModelName, selectedModelNames, setSelectedModelName, @@ -353,10 +376,20 @@ export default function ModelPricingEditor({ title={selectedModel ? selectedModel.name : t('模型计费编辑器')} headerExtraContent={ selectedModel ? ( - + {selectedModel.billingMode === 'per-request' ? t('按次计费') - : t('按量计费')} + : selectedModel.billingMode === 'tiered_expr' + ? 
getExprModeLabel(selectedModel) + : t('按量计费')} ) : null } @@ -381,10 +414,11 @@ export default function ModelPricingEditor({ > {t('按量计费')} {t('按次计费')} + {t('表达式/阶梯计费')}
{t( - '这个界面默认按价格填写,保存时会自动换算回后端需要的倍率 JSON。', + '普通按量/按次直接填价格就行;如果价格要跟请求参数或请求头联动,请切到表达式/阶梯计费。', )}
@@ -415,6 +449,14 @@ export default function ModelPricingEditor({ onChange={(value) => handleNumericFieldChange('fixedPrice', value)} extraText={t('适合 MJ / 任务类等按次收费模型。')} /> + ) : selectedModel.billingMode === 'tiered_expr' ? ( + ) : ( <> . + +For commercial licensing, please contact support@quantumnous.com +*/ +import React, { useCallback, useEffect, useMemo, useState } from 'react'; +import { + Banner, + Button, + Card, + Collapsible, + Input, + InputNumber, + Radio, + RadioGroup, + Select, + Tag, + TextArea, + Typography, +} from '@douyinfe/semi-ui'; +import { IconDelete, IconPlus } from '@douyinfe/semi-icons'; +import { renderQuota } from '../../../../helpers/render'; +import { + createEmptyCondition, + createEmptyTimeCondition, + createEmptyRuleGroup, + createEmptyTimeRuleGroup, + getRequestRuleMatchOptions, + normalizeCondition, + tryParseRequestRuleExpr, + buildRequestRuleExpr, + combineBillingExpr, + splitBillingExprAndRequestRules, + MATCH_EQ, + MATCH_EXISTS, + MATCH_CONTAINS, + MATCH_RANGE, + SOURCE_HEADER, + SOURCE_PARAM, + SOURCE_TIME, + TIME_FUNCS, + COMMON_TIMEZONES, +} from './requestRuleExpr'; + +const { Text } = Typography; + +const PRICE_SUFFIX = '$/1M tokens'; + +function unitCostToPrice(uc) { + return (Number(uc) || 0) * 2; +} +function priceToUnitCost(price) { + return (Number(price) || 0) / 2; +} + +const OPS = ['<', '<=', '>', '>=']; +const VAR_OPTIONS = [ + { value: 'p', label: 'p (输入)' }, + { value: 'c', label: 'c (输出)' }, +]; + +const CACHE_MODE_TIMED = 'timed'; +const CACHE_MODE_GENERIC = 'generic'; + +function formatTokenHint(n) { + if (n == null || n === '' || Number.isNaN(Number(n))) return ''; + const v = Number(n); + if (v === 0) return '= 0'; + if (v >= 1000000) return `= ${(v / 1000000).toLocaleString()}M tokens`; + if (v >= 1000) return `= ${(v / 1000).toLocaleString()}K tokens`; + return `= ${v.toLocaleString()} tokens`; +} + +// --------------------------------------------------------------------------- +// Expr generation from 
visual config (multi-condition) +// --------------------------------------------------------------------------- + +function buildConditionStr(conditions) { + if (!conditions || conditions.length === 0) return ''; + return conditions + .filter((c) => c.var && c.op && c.value != null && c.value !== '') + .map((c) => `${c.var} ${c.op} ${c.value}`) + .join(' && '); +} + +// CACHE_VAR_MAP maps tier data fields to Expr variable names +const CACHE_VAR_MAP = [ + { field: 'cache_read_unit_cost', exprVar: 'cr' }, + { field: 'cache_create_unit_cost', exprVar: 'cc' }, + { field: 'cache_create_1h_unit_cost', exprVar: 'cc1h' }, +]; + +function getTierCacheMode(tier) { + if (tier?.cache_mode === CACHE_MODE_TIMED) { + return CACHE_MODE_TIMED; + } + if (tier?.cache_mode === CACHE_MODE_GENERIC) { + return CACHE_MODE_GENERIC; + } + return Number(tier?.cache_create_1h_unit_cost) > 0 + ? CACHE_MODE_TIMED + : CACHE_MODE_GENERIC; +} + +function normalizeVisualTier(tier = {}) { + return { + ...tier, + conditions: Array.isArray(tier.conditions) ? 
tier.conditions : [], + cache_mode: getTierCacheMode(tier), + }; +} + +function createDefaultVisualConfig() { + return { + tiers: [ + normalizeVisualTier({ + conditions: [], + input_unit_cost: 0, + output_unit_cost: 0, + label: '默认', + cache_mode: CACHE_MODE_GENERIC, + }), + ], + }; +} + +function normalizeVisualConfig(config) { + if (!config || !Array.isArray(config.tiers) || config.tiers.length === 0) { + return createDefaultVisualConfig(); + } + return { + ...config, + tiers: config.tiers.map((tier) => normalizeVisualTier(tier)), + }; +} + +function buildTierBodyExpr(tier) { + const parts = []; + const ic = Number(tier.input_unit_cost) || 0; + const oc = Number(tier.output_unit_cost) || 0; + parts.push(`p * ${ic}`); + parts.push(`c * ${oc}`); + for (const cv of CACHE_VAR_MAP) { + const v = Number(tier[cv.field]) || 0; + if (v !== 0) parts.push(`${cv.exprVar} * ${v}`); + } + return parts.join(' + '); +} + +function generateExprFromVisualConfig(config) { + if (!config || !config.tiers || config.tiers.length === 0) + return 'p * 0 + c * 0'; + const tiers = config.tiers; + + if (tiers.length === 1) { + const t = tiers[0]; + const label = t.label || 'default'; + return `tier("${label}", ${buildTierBodyExpr(t)})`; + } + + const parts = []; + for (let i = 0; i < tiers.length; i++) { + const t = tiers[i]; + const label = t.label || `第${i + 1}档`; + const body = `tier("${label}", ${buildTierBodyExpr(t)})`; + const cond = buildConditionStr(t.conditions); + + if (i < tiers.length - 1 && cond) { + parts.push(`${cond} ? 
${body}`); + } else { + parts.push(body); + } + } + return parts.join(' : '); +} + +// --------------------------------------------------------------------------- +// Reverse-parse an Expr string back into visual config +// --------------------------------------------------------------------------- + +function tryParseVisualConfig(exprStr) { + if (!exprStr) return null; + try { + const cacheVarNames = CACHE_VAR_MAP.map((cv) => cv.exprVar); + const optCacheStr = cacheVarNames + .map((v) => `(?:\\s*\\+\\s*${v}\\s*\\*\\s*([\\d.eE+-]+))?`) + .join(''); + + // Body pattern: p * X + c * Y [+ cr * A] [+ cc * B] [+ cc1h * C] + const bodyPat = `p\\s*\\*\\s*([\\d.eE+-]+)\\s*\\+\\s*c\\s*\\*\\s*([\\d.eE+-]+)${optCacheStr}`; + + // Single-tier: tier("label", body) + const singleRe = new RegExp(`^tier\\("([^"]*)",\\s*${bodyPat}\\)$`); + const simple = exprStr.match(singleRe); + if (simple) { + const tier = { + conditions: [], + input_unit_cost: Number(simple[2]), + output_unit_cost: Number(simple[3]), + label: simple[1], + }; + CACHE_VAR_MAP.forEach((cv, i) => { + const val = simple[4 + i]; + if (val != null) tier[cv.field] = Number(val); + }); + return normalizeVisualConfig({ tiers: [normalizeVisualTier(tier)] }); + } + + // Multi-tier: cond1 ? tier(body) : cond2 ? 
tier(body) : tier(body) + const condGroup = `((?:(?:p|c)\\s*(?:<|<=|>|>=)\\s*[\\d.eE+]+)(?:\\s*&&\\s*(?:p|c)\\s*(?:<|<=|>|>=)\\s*[\\d.eE+]+)*)`; + const tierRe = new RegExp( + `(?:${condGroup}\\s*\\?\\s*)?tier\\("([^"]*)",\\s*${bodyPat}\\)`, + 'g', + ); + const tiers = []; + let match; + while ((match = tierRe.exec(exprStr)) !== null) { + const condStr = match[1] || ''; + const conditions = []; + if (condStr) { + const condParts = condStr.split(/\s*&&\s*/); + for (const cp of condParts) { + const cm = cp.trim().match(/^(p|c)\s*(<|<=|>|>=)\s*([\d.eE+]+)$/); + if (cm) { + conditions.push({ var: cm[1], op: cm[2], value: Number(cm[3]) }); + } + } + } + const tier = { + conditions, + input_unit_cost: Number(match[3]), + output_unit_cost: Number(match[4]), + label: match[2], + }; + CACHE_VAR_MAP.forEach((cv, i) => { + const val = match[5 + i]; + if (val != null) tier[cv.field] = Number(val); + }); + tiers.push(normalizeVisualTier(tier)); + } + if (tiers.length === 0) return null; + + const cfg = normalizeVisualConfig({ tiers }); + const regenerated = generateExprFromVisualConfig(cfg); + if (regenerated.replace(/\s+/g, '') !== exprStr.replace(/\s+/g, '')) + return null; + return cfg; + } catch { + return null; + } +} + +// --------------------------------------------------------------------------- +// Condition editor row +// --------------------------------------------------------------------------- + +function ConditionRow({ cond, onChange, onRemove, t }) { + const hint = formatTokenHint(cond.value); + return ( +
+
+ + + onChange({ ...cond, value: val })} + style={{ flex: 1, minWidth: 100 }} + /> +
+ {hint ? ( + + {hint} + + ) : null} +
+ ); +} + +// --------------------------------------------------------------------------- +// Price input that preserves intermediate text like "7." or "0.5" +// --------------------------------------------------------------------------- + +function PriceInput({ unitCost, field, index, onUpdate, placeholder }) { + const priceFromModel = unitCostToPrice(unitCost); + const [text, setText] = useState(priceFromModel === 0 ? '' : String(priceFromModel)); + + useEffect(() => { + const current = Number(text); + if (text === '' && priceFromModel === 0) return; + if (!Number.isNaN(current) && current === priceFromModel) return; + setText(priceFromModel === 0 ? '' : String(priceFromModel)); + }, [priceFromModel]); + + const handleChange = (val) => { + setText(val); + if (val === '') { + onUpdate(index, field, 0); + return; + } + const num = Number(val); + if (!Number.isNaN(num)) { + onUpdate(index, field, priceToUnitCost(num)); + } + }; + + return ( + + ); +} + +// --------------------------------------------------------------------------- +// Extended price block (cache fields) — collapsible per tier, with mode switch +// --------------------------------------------------------------------------- + +const CACHE_FIELDS_TIMED = [ + { field: 'cache_read_unit_cost', labelKey: '缓存读取价格' }, + { field: 'cache_create_unit_cost', labelKey: '缓存创建价格(5分钟)' }, + { field: 'cache_create_1h_unit_cost', labelKey: '缓存创建价格(1小时)' }, +]; + +const CACHE_FIELDS_GENERIC = [ + { field: 'cache_read_unit_cost', labelKey: '缓存读取价格' }, + { field: 'cache_create_unit_cost', labelKey: '缓存创建价格' }, +]; + +function ExtendedPriceBlock({ tier, index, onUpdate, t }) { + const hasAny = [...CACHE_FIELDS_TIMED].some( + (f) => Number(tier[f.field]) > 0, + ); + const [expanded, setExpanded] = useState(hasAny); + const cacheMode = getTierCacheMode(tier); + + const handleCacheModeChange = (e) => { + const mode = e.target.value; + const patch = { cache_mode: mode }; + if (mode === CACHE_MODE_GENERIC) { + 
patch.cache_create_1h_unit_cost = 0; + } + onUpdate(index, patch); + }; + + const activeFields = + cacheMode === CACHE_MODE_TIMED ? CACHE_FIELDS_TIMED : CACHE_FIELDS_GENERIC; + + return ( +
+ + +
+
+ {t('这些价格都是可选项,不填也可以。')} +
+
+ + {t('通用缓存')} + {t('分时缓存 (Claude)')} + +
+
+ {activeFields.map((cf) => ( +
+ + {t(cf.labelKey)} + + +
+ ))} +
+
+
+
+ ); +} + +// --------------------------------------------------------------------------- +// Visual Tier Card (multi-condition) +// --------------------------------------------------------------------------- + +function VisualTierCard({ tier, index, isLast, isOnly, onUpdate, onRemove, t }) { + const conditions = tier.conditions || []; + + const varLabel = { p: t('输入'), c: t('输出') }; + const condSummary = useMemo(() => { + if (conditions.length === 0) return t('无条件(兜底档)'); + return conditions + .filter((c) => c.var && c.op && c.value != null) + .map((c) => `${varLabel[c.var] || c.var} ${c.op} ${formatTokenHint(c.value)}`) + .join(' && '); + }, [conditions, t]); + + const updateCondition = (ci, newCond) => { + const next = conditions.map((c, i) => (i === ci ? newCond : c)); + onUpdate(index, 'conditions', next); + }; + + const removeCondition = (ci) => { + onUpdate( + index, + 'conditions', + conditions.filter((_, i) => i !== ci), + ); + }; + + const addCondition = () => { + if (conditions.length >= 2) return; + const usedVars = conditions.map((c) => c.var); + const nextVar = usedVars.includes('p') ? 'c' : 'p'; + onUpdate(index, 'conditions', [ + ...conditions, + { var: nextVar, op: '<', value: 200000 }, + ]); + }; + + return ( +
+
+
+ + {t('第 {{n}} 档', { n: index + 1 })} + + {isLast && !isOnly ? ( + + {t('兜底档')} + + ) : null} +
+ {!isOnly ? ( +
+ + {/* Tier label */} +
+ + {t('档位名称')} + + onUpdate(index, 'label', val)} + style={{ width: '100%', marginTop: 2 }} + /> +
+ + {/* Conditions */} + {!isLast || isOnly ? ( +
+ + {t('条件')} + + {conditions.map((cond, ci) => ( + updateCondition(ci, nc)} + onRemove={() => removeCondition(ci)} + t={t} + /> + ))} + {conditions.length < 2 && ( + + )} +
+ ) : ( +
+ + {condSummary} + +
+ )} + + {/* Prices */} +
+
+ + {t('输入价格')} + + +
+
+ + {t('输出价格')} + + +
+
+ + {/* Extended prices (cache) — collapsible */} + +
+ ); +} + +// --------------------------------------------------------------------------- +// Visual editor +// --------------------------------------------------------------------------- + +function VisualEditor({ visualConfig, onChange, t }) { + const config = normalizeVisualConfig(visualConfig); + const tiers = config.tiers || []; + + const updateTier = (index, field, value) => { + const patch = + typeof field === 'string' ? { [field]: value } : { ...field }; + const next = tiers.map((tier, i) => + i === index ? normalizeVisualTier({ ...tier, ...patch }) : tier, + ); + onChange({ ...config, tiers: next }); + }; + + const addTier = () => { + const newTiers = [...tiers]; + if ( + newTiers.length > 0 && + (!newTiers[newTiers.length - 1].conditions || + newTiers[newTiers.length - 1].conditions.length === 0) + ) { + newTiers[newTiers.length - 1] = { + ...newTiers[newTiers.length - 1], + conditions: [{ var: 'p', op: '<', value: 200000 }], + }; + } + newTiers.push({ + conditions: [], + input_unit_cost: 0, + output_unit_cost: 0, + label: `第${newTiers.length + 1}档`, + cache_mode: CACHE_MODE_GENERIC, + }); + onChange({ ...config, tiers: newTiers }); + }; + + const removeTier = (index) => { + if (tiers.length <= 1) return; + const next = tiers.filter((_, i) => i !== index); + if (next.length > 0) { + next[next.length - 1] = { + ...next[next.length - 1], + conditions: [], + }; + } + onChange({ ...config, tiers: next }); + }; + + return ( +
+ + + {tiers.map((tier, index) => ( + + ))} + +
+ ); +} + +// --------------------------------------------------------------------------- +// Raw Expr editor with preset templates +// --------------------------------------------------------------------------- + +const PRESETS = [ + { + key: 'claude-opus', + label: 'Claude Opus 4.6', + expr: 'tier("default", p * 2.5 + c * 12.5 + cr * 0.25 + cc * 3.125 + cc1h * 5)', + }, + { + key: 'claude-opus-fast', + label: 'Claude Opus 4.6 Fast', + expr: 'tier("default", p * 2.5 + c * 12.5 + cr * 0.25 + cc * 3.125 + cc1h * 5)', + requestRules: [ + { conditions: [{ source: SOURCE_HEADER, path: 'anthropic-beta', mode: MATCH_CONTAINS, value: 'fast-mode-2026-02-01' }], multiplier: '6' }, + ], + }, + { + key: 'claude-sonnet', + label: 'Claude Sonnet 4.5', + expr: 'p <= 200000 ? tier("standard", p * 1.5 + c * 7.5 + cr * 0.15 + cc * 1.875 + cc1h * 3) : tier("long_context", p * 3 + c * 11.25 + cr * 0.3 + cc * 3.75 + cc1h * 6)', + }, + { + key: 'glm-4.5-air', + label: 'GLM-4.5-Air', + expr: 'p < 32000 && c < 200 ? tier("short_output", p * 0.4 + c * 1 + cr * 0.08) : p < 32000 && c >= 200 ? 
tier("long_output", p * 0.4 + c * 3 + cr * 0.08) : tier("mid_context", p * 0.6 + c * 4 + cr * 0.12)', + }, + { + key: 'gpt-5.4-fast', + label: 'GPT-5.4 Fast', + expr: 'tier("default", p * 1.25 + c * 5 + cr * 0.125)', + requestRules: [ + { conditions: [{ source: SOURCE_PARAM, path: 'service_tier', mode: MATCH_EQ, value: 'fast' }], multiplier: '2' }, + ], + }, + { + key: 'flat', + label: 'Flat', + expr: 'tier("default", p * 1 + c * 2)', + }, + { + key: 'night-discount', + label: '夜间半价', + expr: 'tier("default", p * 1.5 + c * 7.5)', + requestRules: [ + { conditions: [{ source: SOURCE_TIME, timeFunc: 'hour', timezone: 'Asia/Shanghai', mode: MATCH_RANGE, rangeStart: '21', rangeEnd: '6' }], multiplier: '0.5' }, + ], + }, + { + key: 'weekend-discount', + label: '周末8折', + expr: 'tier("default", p * 1.5 + c * 7.5)', + requestRules: [ + { conditions: [{ source: SOURCE_TIME, timeFunc: 'weekday', timezone: 'Asia/Shanghai', mode: MATCH_EQ, value: '0' }], multiplier: '0.8' }, + { conditions: [{ source: SOURCE_TIME, timeFunc: 'weekday', timezone: 'Asia/Shanghai', mode: MATCH_EQ, value: '6' }], multiplier: '0.8' }, + ], + }, + { + key: 'new-year-promo', + label: '新年促销', + expr: 'tier("default", p * 1.5 + c * 7.5)', + requestRules: [ + { conditions: [ + { source: SOURCE_TIME, timeFunc: 'month', timezone: 'Asia/Shanghai', mode: MATCH_EQ, value: '1' }, + { source: SOURCE_TIME, timeFunc: 'day', timezone: 'Asia/Shanghai', mode: MATCH_EQ, value: '1' }, + ], multiplier: '0.5' }, + ], + }, +]; + +function RawExprEditor({ exprString, onChange, t }) { + return ( +
+ +
+ {t('变量')}: p ({t('输入 Token')}), c ( + {t('输出 Token')}), cr ({t('缓存读取')}),{' '} + cc ({t('缓存创建')}),{' '} + cc1h ({t('缓存创建-1小时')}) +
+
+ {t('函数')}: tier(name, value),{' '} + max(a, b), min(a, b),{' '} + ceil(x), floor(x),{' '} + abs(x), header(name),{' '} + param(path), has(source, text) +
+
+ {t('也支持更好懂的别名')}: prompt_tokens,{' '} + completion_tokens, cache_read_tokens,{' '} + cache_create_tokens,{' '} + cache_create_1h_tokens +
+
+ } + style={{ marginBottom: 12 }} + /> + +