feat: implement tiered billing expression evaluation and related functionality

- Added support for tiered billing expressions in the billing system.
- Introduced new types and functions for handling billing expressions, including caching and execution.
- Updated existing billing logic to accommodate tiered billing scenarios.
- Enhanced request handling to support incoming billing expression requests.
- Added tests for tiered billing functionality to ensure correctness.
This commit is contained in:
CaIon
2026-03-16 16:00:22 +08:00
parent a4fd2246ba
commit 91ed4e196a
34 changed files with 4797 additions and 26 deletions
+3
View File
@@ -29,3 +29,6 @@ data/
.gomodcache/
.gocache-temp
.gopath
token_estimator_test.go
skills-lock.json
+1
View File
@@ -75,6 +75,7 @@ require (
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/expr-lang/expr v1.17.8 // indirect
github.com/fxamacker/cbor/v2 v2.9.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
+2
View File
@@ -68,6 +68,8 @@ github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZ
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/expr-lang/expr v1.17.8 h1:W1loDTT+0PQf5YteHSTpju2qfUfNoBt4yw9+wOEU9VM=
github.com/expr-lang/expr v1.17.8/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM=
+7
View File
@@ -7,6 +7,7 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/billing_setting"
"github.com/QuantumNous/new-api/setting/config"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/QuantumNous/new-api/setting/performance_setting"
@@ -121,6 +122,8 @@ func InitOptionMap() {
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
common.OptionMap["CompletionRatio"] = ratio_setting.CompletionRatio2JSONString()
common.OptionMap["ImageRatio"] = ratio_setting.ImageRatio2JSONString()
common.OptionMap["ModelBillingMode"] = billing_setting.BillingMode2JSONString()
common.OptionMap["ModelBillingExpr"] = billing_setting.BillingExpr2JSONString()
common.OptionMap["AudioRatio"] = ratio_setting.AudioRatio2JSONString()
common.OptionMap["AudioCompletionRatio"] = ratio_setting.AudioCompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
@@ -436,6 +439,10 @@ func updateOptionMap(key string, value string) (err error) {
err = ratio_setting.UpdateAudioRatioByJSONString(value)
case "AudioCompletionRatio":
err = ratio_setting.UpdateAudioCompletionRatioByJSONString(value)
case "ModelBillingMode":
err = billing_setting.UpdateBillingModeByJSONString(value)
case "ModelBillingExpr":
err = billing_setting.UpdateBillingExprByJSONString(value)
case "TopUpLink":
common.TopUpLink = value
//case "ChatLink":
+933
View File
@@ -0,0 +1,933 @@
package billingexpr_test
import (
"math"
"math/rand"
"testing"
"github.com/QuantumNous/new-api/pkg/billingexpr"
)
// ---------------------------------------------------------------------------
// Claude-style: fixed tiers, input > 200K changes both input & output price
// ---------------------------------------------------------------------------
const claudeExpr = `p <= 200000 ? tier("standard", p * 1.5 + c * 7.5) : tier("long_context", p * 3.0 + c * 11.25)`
func TestClaude_StandardTier(t *testing.T) {
cost, trace, err := billingexpr.RunExpr(claudeExpr, billingexpr.TokenParams{P: 100000, C: 5000})
if err != nil {
t.Fatal(err)
}
want := 100000*1.5 + 5000*7.5
if math.Abs(cost-want) > 1e-6 {
t.Errorf("cost = %f, want %f", cost, want)
}
if trace.MatchedTier != "standard" {
t.Errorf("tier = %q, want %q", trace.MatchedTier, "standard")
}
}
func TestClaude_LongContextTier(t *testing.T) {
cost, trace, err := billingexpr.RunExpr(claudeExpr, billingexpr.TokenParams{P: 300000, C: 10000})
if err != nil {
t.Fatal(err)
}
want := 300000*3.0 + 10000*11.25
if math.Abs(cost-want) > 1e-6 {
t.Errorf("cost = %f, want %f", cost, want)
}
if trace.MatchedTier != "long_context" {
t.Errorf("tier = %q, want %q", trace.MatchedTier, "long_context")
}
}
func TestClaude_BoundaryExact(t *testing.T) {
cost, trace, err := billingexpr.RunExpr(claudeExpr, billingexpr.TokenParams{P: 200000, C: 1000})
if err != nil {
t.Fatal(err)
}
want := 200000*1.5 + 1000*7.5
if math.Abs(cost-want) > 1e-6 {
t.Errorf("cost = %f, want %f", cost, want)
}
if trace.MatchedTier != "standard" {
t.Errorf("tier = %q, want %q", trace.MatchedTier, "standard")
}
}
// ---------------------------------------------------------------------------
// GLM-style: multi-condition tiers with both input and output dimensions
// ---------------------------------------------------------------------------
const glmExpr = `
(
p < 32000 && c < 200 ? tier("tier1_short", (p)*2 + c*8) :
p < 32000 && c >= 200 ? tier("tier2_long_output", (p)*3 + c*14) :
tier("tier3_long_input", (p)*4 + c*16)
) / 1000000
`
func TestGLM_Tier1(t *testing.T) {
cost, trace, err := billingexpr.RunExpr(glmExpr, billingexpr.TokenParams{P: 15000, C: 100})
if err != nil {
t.Fatal(err)
}
want := (15000.0*2 + 100.0*8) / 1000000
if math.Abs(cost-want) > 1e-10 {
t.Errorf("cost = %f, want %f", cost, want)
}
if trace.MatchedTier != "tier1_short" {
t.Errorf("tier = %q, want %q", trace.MatchedTier, "tier1_short")
}
}
func TestGLM_Tier2(t *testing.T) {
cost, trace, err := billingexpr.RunExpr(glmExpr, billingexpr.TokenParams{P: 15000, C: 500})
if err != nil {
t.Fatal(err)
}
want := (15000.0*3 + 500.0*14) / 1000000
if math.Abs(cost-want) > 1e-10 {
t.Errorf("cost = %f, want %f", cost, want)
}
if trace.MatchedTier != "tier2_long_output" {
t.Errorf("tier = %q, want %q", trace.MatchedTier, "tier2_long_output")
}
}
func TestGLM_Tier3(t *testing.T) {
cost, trace, err := billingexpr.RunExpr(glmExpr, billingexpr.TokenParams{P: 50000, C: 100})
if err != nil {
t.Fatal(err)
}
want := (50000.0*4 + 100.0*16) / 1000000
if math.Abs(cost-want) > 1e-10 {
t.Errorf("cost = %f, want %f", cost, want)
}
if trace.MatchedTier != "tier3_long_input" {
t.Errorf("tier = %q, want %q", trace.MatchedTier, "tier3_long_input")
}
}
// ---------------------------------------------------------------------------
// Simple flat-rate (no tier() call)
// ---------------------------------------------------------------------------
func TestSimpleExpr_NoTier(t *testing.T) {
cost, trace, err := billingexpr.RunExpr("p * 0.5 + c * 1.0", billingexpr.TokenParams{P: 1000, C: 500})
if err != nil {
t.Fatal(err)
}
want := 1000*0.5 + 500*1.0
if math.Abs(cost-want) > 1e-6 {
t.Errorf("cost = %f, want %f", cost, want)
}
if trace.MatchedTier != "" {
t.Errorf("tier should be empty, got %q", trace.MatchedTier)
}
}
// ---------------------------------------------------------------------------
// Math helper functions
// ---------------------------------------------------------------------------
func TestMathHelpers(t *testing.T) {
cost, _, err := billingexpr.RunExpr("max(p, c) * 0.5 + min(p, c) * 0.1", billingexpr.TokenParams{P: 300, C: 500})
if err != nil {
t.Fatal(err)
}
want := 500*0.5 + 300*0.1
if math.Abs(cost-want) > 1e-6 {
t.Errorf("cost = %f, want %f", cost, want)
}
}
func TestRequestProbeHelpers(t *testing.T) {
cost, _, err := billingexpr.RunExprWithRequest(
`prompt_tokens * 0.5 + completion_tokens * 1.0 * (param("service_tier") == "fast" ? 2 : 1)`,
billingexpr.TokenParams{P: 1000, C: 500},
billingexpr.RequestInput{
Body: []byte(`{"service_tier":"fast"}`),
},
)
if err != nil {
t.Fatal(err)
}
want := 1000*0.5 + 500*1.0*2
if math.Abs(cost-want) > 1e-6 {
t.Errorf("cost = %f, want %f", cost, want)
}
}
func TestHeaderProbeHelper(t *testing.T) {
cost, _, err := billingexpr.RunExprWithRequest(
`p * 0.5 + c * 1.0 * (has(header("anthropic-beta"), "fast-mode") ? 2 : 1)`,
billingexpr.TokenParams{P: 1000, C: 500},
billingexpr.RequestInput{
Headers: map[string]string{
"Anthropic-Beta": "fast-mode-2026-02-01",
},
},
)
if err != nil {
t.Fatal(err)
}
want := 1000*0.5 + 500*1.0*2
if math.Abs(cost-want) > 1e-6 {
t.Errorf("cost = %f, want %f", cost, want)
}
}
func TestParamProbeNestedBool(t *testing.T) {
cost, _, err := billingexpr.RunExprWithRequest(
`p * (param("stream_options.fast_mode") == true ? 1.5 : 1.0)`,
billingexpr.TokenParams{P: 100},
billingexpr.RequestInput{
Body: []byte(`{"stream_options":{"fast_mode":true}}`),
},
)
if err != nil {
t.Fatal(err)
}
want := 150.0
if math.Abs(cost-want) > 1e-6 {
t.Errorf("cost = %f, want %f", cost, want)
}
}
func TestParamProbeArrayLength(t *testing.T) {
cost, _, err := billingexpr.RunExprWithRequest(
`p * (param("messages.#") > 20 ? 1.2 : 1.0)`,
billingexpr.TokenParams{P: 100},
billingexpr.RequestInput{
Body: []byte(`{"messages":[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21]}`),
},
)
if err != nil {
t.Fatal(err)
}
want := 120.0
if math.Abs(cost-want) > 1e-6 {
t.Errorf("cost = %f, want %f", cost, want)
}
}
func TestRequestProbeMissingFieldReturnsNil(t *testing.T) {
cost, _, err := billingexpr.RunExprWithRequest(
`param("missing.value") == nil ? 2 : 1`,
billingexpr.TokenParams{},
billingexpr.RequestInput{
Body: []byte(`{"service_tier":"standard"}`),
},
)
if err != nil {
t.Fatal(err)
}
if cost != 2 {
t.Errorf("cost = %f, want 2", cost)
}
}
func TestRequestProbeMultipleRulesMultiply(t *testing.T) {
cost, _, err := billingexpr.RunExprWithRequest(
`(param("service_tier") == "fast" ? 2 : 1) * (has(header("anthropic-beta"), "fast-mode-2026-02-01") ? 2.5 : 1)`,
billingexpr.TokenParams{},
billingexpr.RequestInput{
Headers: map[string]string{
"Anthropic-Beta": "fast-mode-2026-02-01",
},
Body: []byte(`{"service_tier":"fast"}`),
},
)
if err != nil {
t.Fatal(err)
}
if math.Abs(cost-5) > 1e-6 {
t.Errorf("cost = %f, want 5", cost)
}
}
func TestCeilFloor(t *testing.T) {
cost, _, err := billingexpr.RunExpr("ceil(p / 1000) * 0.5", billingexpr.TokenParams{P: 1500})
if err != nil {
t.Fatal(err)
}
want := math.Ceil(1500.0/1000) * 0.5
if math.Abs(cost-want) > 1e-6 {
t.Errorf("cost = %f, want %f", cost, want)
}
}
// ---------------------------------------------------------------------------
// Zero tokens
// ---------------------------------------------------------------------------
func TestZeroTokens(t *testing.T) {
cost, _, err := billingexpr.RunExpr(claudeExpr, billingexpr.TokenParams{})
if err != nil {
t.Fatal(err)
}
if cost != 0 {
t.Errorf("cost should be 0 for zero tokens, got %f", cost)
}
}
// ---------------------------------------------------------------------------
// Rounding
// ---------------------------------------------------------------------------
func TestQuotaRound(t *testing.T) {
tests := []struct {
in float64
want int
}{
{0, 0},
{0.4, 0},
{0.5, 1},
{0.6, 1},
{1.5, 2},
{-0.5, -1},
{-0.6, -1},
{999.4999, 999},
{999.5, 1000},
{1e9 + 0.5, 1e9 + 1},
}
for _, tt := range tests {
got := billingexpr.QuotaRound(tt.in)
if got != tt.want {
t.Errorf("QuotaRound(%f) = %d, want %d", tt.in, got, tt.want)
}
}
}
// ---------------------------------------------------------------------------
// Settlement
// ---------------------------------------------------------------------------
func TestComputeTieredQuota_Basic(t *testing.T) {
snap := &billingexpr.BillingSnapshot{
BillingMode: "tiered_expr",
ExprString: claudeExpr,
ExprHash: billingexpr.ExprHashString(claudeExpr),
GroupRatio: 1.0,
EstimatedPromptTokens: 100000,
EstimatedCompletionTokens: 5000,
EstimatedQuotaBeforeGroup: 100000*1.5 + 5000*7.5,
EstimatedQuotaAfterGroup: billingexpr.QuotaRound(100000*1.5 + 5000*7.5),
EstimatedTier: "standard",
}
result, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{P: 300000, C: 10000})
if err != nil {
t.Fatal(err)
}
wantBefore := 300000*3.0 + 10000*11.25
if math.Abs(result.ActualQuotaBeforeGroup-wantBefore) > 1e-6 {
t.Errorf("before group: got %f, want %f", result.ActualQuotaBeforeGroup, wantBefore)
}
if result.MatchedTier != "long_context" {
t.Errorf("tier = %q, want %q", result.MatchedTier, "long_context")
}
if !result.CrossedTier {
t.Error("expected crossed_tier=true (estimated standard, actual long_context)")
}
}
func TestComputeTieredQuota_SameTier(t *testing.T) {
snap := &billingexpr.BillingSnapshot{
BillingMode: "tiered_expr",
ExprString: claudeExpr,
ExprHash: billingexpr.ExprHashString(claudeExpr),
GroupRatio: 1.5,
EstimatedPromptTokens: 50000,
EstimatedCompletionTokens: 1000,
EstimatedQuotaBeforeGroup: 50000*1.5 + 1000*7.5,
EstimatedQuotaAfterGroup: billingexpr.QuotaRound((50000*1.5 + 1000*7.5) * 1.5),
EstimatedTier: "standard",
}
result, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{P: 80000, C: 2000})
if err != nil {
t.Fatal(err)
}
wantBefore := 80000*1.5 + 2000*7.5
wantAfter := billingexpr.QuotaRound(wantBefore * 1.5)
if result.ActualQuotaAfterGroup != wantAfter {
t.Errorf("after group: got %d, want %d", result.ActualQuotaAfterGroup, wantAfter)
}
if result.CrossedTier {
t.Error("expected crossed_tier=false (both standard)")
}
}
// ---------------------------------------------------------------------------
// Compile errors
// ---------------------------------------------------------------------------
func TestCompileError(t *testing.T) {
_, _, err := billingexpr.RunExpr("invalid +-+ syntax", billingexpr.TokenParams{})
if err == nil {
t.Error("expected compile error")
}
}
// ---------------------------------------------------------------------------
// Compile Cache
// ---------------------------------------------------------------------------
func TestCompileCache_SameResult(t *testing.T) {
r1, _, err := billingexpr.RunExpr("p * 0.5", billingexpr.TokenParams{P: 100})
if err != nil {
t.Fatal(err)
}
r2, _, err := billingexpr.RunExpr("p * 0.5", billingexpr.TokenParams{P: 100})
if err != nil {
t.Fatal(err)
}
if r1 != r2 {
t.Errorf("cached and uncached results differ: %f != %f", r1, r2)
}
}
func TestInvalidateCache(t *testing.T) {
billingexpr.InvalidateCache()
r1, _, _ := billingexpr.RunExpr("p * 0.5", billingexpr.TokenParams{P: 100})
billingexpr.InvalidateCache()
r2, _, _ := billingexpr.RunExpr("p * 0.5", billingexpr.TokenParams{P: 100})
if r1 != r2 {
t.Errorf("post-invalidate results differ: %f != %f", r1, r2)
}
}
// ---------------------------------------------------------------------------
// Hash
// ---------------------------------------------------------------------------
func TestExprHashString_Deterministic(t *testing.T) {
h1 := billingexpr.ExprHashString("p * 0.5")
h2 := billingexpr.ExprHashString("p * 0.5")
if h1 != h2 {
t.Error("hash should be deterministic")
}
h3 := billingexpr.ExprHashString("p * 0.6")
if h1 == h3 {
t.Error("different expressions should have different hashes")
}
}
// ---------------------------------------------------------------------------
// Cache variables: present
// ---------------------------------------------------------------------------
const claudeWithCacheExpr = `p <= 200000 ? tier("standard", p * 1.5 + c * 7.5 + cr * 0.15 + cc * 1.875) : tier("long_context", p * 3.0 + c * 11.25 + cr * 0.3 + cc * 3.75)`
func TestCachePresent_StandardTier(t *testing.T) {
params := billingexpr.TokenParams{P: 100000, C: 5000, CR: 50000, CC: 10000}
cost, trace, err := billingexpr.RunExpr(claudeWithCacheExpr, params)
if err != nil {
t.Fatal(err)
}
want := 100000*1.5 + 5000*7.5 + 50000*0.15 + 10000*1.875
if math.Abs(cost-want) > 1e-6 {
t.Errorf("cost = %f, want %f", cost, want)
}
if trace.MatchedTier != "standard" {
t.Errorf("tier = %q, want %q", trace.MatchedTier, "standard")
}
}
func TestCachePresent_LongContextTier(t *testing.T) {
params := billingexpr.TokenParams{P: 300000, C: 10000, CR: 100000, CC: 20000}
cost, trace, err := billingexpr.RunExpr(claudeWithCacheExpr, params)
if err != nil {
t.Fatal(err)
}
want := 300000*3.0 + 10000*11.25 + 100000*0.3 + 20000*3.75
if math.Abs(cost-want) > 1e-6 {
t.Errorf("cost = %f, want %f", cost, want)
}
if trace.MatchedTier != "long_context" {
t.Errorf("tier = %q, want %q", trace.MatchedTier, "long_context")
}
}
// ---------------------------------------------------------------------------
// Cache variables: absent (all zero) — same expression still works
// ---------------------------------------------------------------------------
func TestCacheAbsent_ZeroCacheTokens(t *testing.T) {
params := billingexpr.TokenParams{P: 100000, C: 5000}
cost, trace, err := billingexpr.RunExpr(claudeWithCacheExpr, params)
if err != nil {
t.Fatal(err)
}
want := 100000*1.5 + 5000*7.5
if math.Abs(cost-want) > 1e-6 {
t.Errorf("cost = %f, want %f (cache terms should be 0)", cost, want)
}
if trace.MatchedTier != "standard" {
t.Errorf("tier = %q, want %q", trace.MatchedTier, "standard")
}
}
// ---------------------------------------------------------------------------
// Mixed cache fields: cc and cc1h non-zero
// ---------------------------------------------------------------------------
const claudeCacheSplitExpr = `tier("default", p * 1.5 + c * 7.5 + cr * 0.15 + cc * 2.0 + cc1h * 3.0)`
func TestMixedCacheFields(t *testing.T) {
params := billingexpr.TokenParams{P: 100000, C: 5000, CR: 10000, CC: 5000, CC1h: 2000}
cost, _, err := billingexpr.RunExpr(claudeCacheSplitExpr, params)
if err != nil {
t.Fatal(err)
}
want := 100000*1.5 + 5000*7.5 + 10000*0.15 + 5000*2.0 + 2000*3.0
if math.Abs(cost-want) > 1e-6 {
t.Errorf("cost = %f, want %f", cost, want)
}
}
func TestMixedCacheFields_AllCacheZero(t *testing.T) {
params := billingexpr.TokenParams{P: 100000, C: 5000}
cost, _, err := billingexpr.RunExpr(claudeCacheSplitExpr, params)
if err != nil {
t.Fatal(err)
}
want := 100000*1.5 + 5000*7.5
if math.Abs(cost-want) > 1e-6 {
t.Errorf("cost = %f, want %f (all cache zero)", cost, want)
}
}
// ---------------------------------------------------------------------------
// Backward compatibility: p+c only expressions still work with TokenParams
// ---------------------------------------------------------------------------
func TestBackwardCompat_OldExprWithTokenParams(t *testing.T) {
params := billingexpr.TokenParams{P: 100000, C: 5000, CR: 99999, CC: 88888}
cost, trace, err := billingexpr.RunExpr(claudeExpr, params)
if err != nil {
t.Fatal(err)
}
want := 100000*1.5 + 5000*7.5
if math.Abs(cost-want) > 1e-6 {
t.Errorf("cost = %f, want %f (old expr ignores cache fields)", cost, want)
}
if trace.MatchedTier != "standard" {
t.Errorf("tier = %q, want %q", trace.MatchedTier, "standard")
}
}
// ---------------------------------------------------------------------------
// Settlement with cache tokens
// ---------------------------------------------------------------------------
func TestComputeTieredQuota_WithCache(t *testing.T) {
snap := &billingexpr.BillingSnapshot{
BillingMode: "tiered_expr",
ExprString: claudeWithCacheExpr,
ExprHash: billingexpr.ExprHashString(claudeWithCacheExpr),
GroupRatio: 1.0,
EstimatedPromptTokens: 100000,
EstimatedCompletionTokens: 5000,
EstimatedQuotaBeforeGroup: 100000*1.5 + 5000*7.5,
EstimatedQuotaAfterGroup: billingexpr.QuotaRound(100000*1.5 + 5000*7.5),
EstimatedTier: "standard",
}
params := billingexpr.TokenParams{P: 100000, C: 5000, CR: 50000, CC: 10000}
result, err := billingexpr.ComputeTieredQuota(snap, params)
if err != nil {
t.Fatal(err)
}
wantBefore := 100000*1.5 + 5000*7.5 + 50000*0.15 + 10000*1.875
if math.Abs(result.ActualQuotaBeforeGroup-wantBefore) > 1e-6 {
t.Errorf("before group: got %f, want %f", result.ActualQuotaBeforeGroup, wantBefore)
}
if result.MatchedTier != "standard" {
t.Errorf("tier = %q, want %q", result.MatchedTier, "standard")
}
if result.CrossedTier {
t.Error("expected crossed_tier=false (same tier)")
}
}
func TestComputeTieredQuota_WithCacheCrossTier(t *testing.T) {
snap := &billingexpr.BillingSnapshot{
BillingMode: "tiered_expr",
ExprString: claudeWithCacheExpr,
ExprHash: billingexpr.ExprHashString(claudeWithCacheExpr),
GroupRatio: 2.0,
EstimatedPromptTokens: 100000,
EstimatedCompletionTokens: 5000,
EstimatedQuotaBeforeGroup: 100000*1.5 + 5000*7.5,
EstimatedQuotaAfterGroup: billingexpr.QuotaRound((100000*1.5 + 5000*7.5) * 2.0),
EstimatedTier: "standard",
}
params := billingexpr.TokenParams{P: 300000, C: 10000, CR: 50000, CC: 10000}
result, err := billingexpr.ComputeTieredQuota(snap, params)
if err != nil {
t.Fatal(err)
}
wantBefore := 300000*3.0 + 10000*11.25 + 50000*0.3 + 10000*3.75
wantAfter := billingexpr.QuotaRound(wantBefore * 2.0)
if math.Abs(result.ActualQuotaBeforeGroup-wantBefore) > 1e-6 {
t.Errorf("before group: got %f, want %f", result.ActualQuotaBeforeGroup, wantBefore)
}
if result.ActualQuotaAfterGroup != wantAfter {
t.Errorf("after group: got %d, want %d", result.ActualQuotaAfterGroup, wantAfter)
}
if !result.CrossedTier {
t.Error("expected crossed_tier=true (estimated standard, actual long_context)")
}
}
// ---------------------------------------------------------------------------
// Fuzz: random p/c/cache, verify non-negative result
// ---------------------------------------------------------------------------
func TestFuzz_NonNegativeResults(t *testing.T) {
exprs := []string{
claudeExpr,
claudeWithCacheExpr,
glmExpr,
"p * 0.5 + c * 1.0",
"max(p, c) * 0.1",
"p * 0.5 + cr * 0.1 + cc * 0.2",
}
rng := rand.New(rand.NewSource(42))
for _, exprStr := range exprs {
for i := 0; i < 500; i++ {
params := billingexpr.TokenParams{
P: math.Round(rng.Float64() * 1000000),
C: math.Round(rng.Float64() * 500000),
CR: math.Round(rng.Float64() * 200000),
CC: math.Round(rng.Float64() * 50000),
CC1h: math.Round(rng.Float64() * 10000),
}
cost, _, err := billingexpr.RunExpr(exprStr, params)
if err != nil {
t.Fatalf("expr=%q params=%+v: %v", exprStr, params, err)
}
if cost < 0 {
t.Errorf("expr=%q params=%+v: negative cost %f", exprStr, params, cost)
}
}
}
}
func TestFuzz_SettlementConsistency(t *testing.T) {
rng := rand.New(rand.NewSource(99))
for i := 0; i < 200; i++ {
estParams := billingexpr.TokenParams{
P: math.Round(rng.Float64() * 500000),
C: math.Round(rng.Float64() * 100000),
CR: math.Round(rng.Float64() * 100000),
CC: math.Round(rng.Float64() * 30000),
}
actParams := billingexpr.TokenParams{
P: math.Round(rng.Float64() * 500000),
C: math.Round(rng.Float64() * 100000),
CR: math.Round(rng.Float64() * 100000),
CC: math.Round(rng.Float64() * 30000),
}
groupRatio := 0.5 + rng.Float64()*2.0
estCost, estTrace, _ := billingexpr.RunExpr(claudeWithCacheExpr, estParams)
snap := &billingexpr.BillingSnapshot{
BillingMode: "tiered_expr",
ExprString: claudeWithCacheExpr,
ExprHash: billingexpr.ExprHashString(claudeWithCacheExpr),
GroupRatio: groupRatio,
EstimatedPromptTokens: int(estParams.P),
EstimatedCompletionTokens: int(estParams.C),
EstimatedQuotaBeforeGroup: estCost,
EstimatedQuotaAfterGroup: billingexpr.QuotaRound(estCost * groupRatio),
EstimatedTier: estTrace.MatchedTier,
}
result, err := billingexpr.ComputeTieredQuota(snap, actParams)
if err != nil {
t.Fatalf("iter %d: %v", i, err)
}
directCost, _, _ := billingexpr.RunExpr(claudeWithCacheExpr, actParams)
directQuota := billingexpr.QuotaRound(directCost * groupRatio)
if result.ActualQuotaAfterGroup != directQuota {
t.Errorf("iter %d: settlement %d != direct %d", i, result.ActualQuotaAfterGroup, directQuota)
}
}
}
// ---------------------------------------------------------------------------
// Settlement-level tests for ComputeTieredQuota
// ---------------------------------------------------------------------------
func TestComputeTieredQuota_BasicSettlement(t *testing.T) {
exprStr := `tier("default", p + c)`
snap := &billingexpr.BillingSnapshot{
BillingMode: "tiered_expr",
ExprString: exprStr,
ExprHash: billingexpr.ExprHashString(exprStr),
GroupRatio: 1.0,
}
result, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{P: 3000, C: 2000})
if err != nil {
t.Fatal(err)
}
if math.Abs(result.ActualQuotaBeforeGroup-5000) > 1e-6 {
t.Errorf("before group = %f, want 5000", result.ActualQuotaBeforeGroup)
}
if result.ActualQuotaAfterGroup != 5000 {
t.Errorf("after group = %d, want 5000", result.ActualQuotaAfterGroup)
}
if result.MatchedTier != "default" {
t.Errorf("tier = %q, want default", result.MatchedTier)
}
}
func TestComputeTieredQuota_WithGroupRatio(t *testing.T) {
exprStr := `tier("default", p + c)`
snap := &billingexpr.BillingSnapshot{
BillingMode: "tiered_expr",
ExprString: exprStr,
ExprHash: billingexpr.ExprHashString(exprStr),
GroupRatio: 2.0,
}
result, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{P: 1000, C: 500})
if err != nil {
t.Fatal(err)
}
// cost = 1500, after group = round(1500 * 2.0) = 3000
if result.ActualQuotaAfterGroup != 3000 {
t.Errorf("after group = %d, want 3000", result.ActualQuotaAfterGroup)
}
}
func TestComputeTieredQuota_ZeroTokens(t *testing.T) {
exprStr := `tier("default", p * 2 + c * 10)`
snap := &billingexpr.BillingSnapshot{
BillingMode: "tiered_expr",
ExprString: exprStr,
ExprHash: billingexpr.ExprHashString(exprStr),
GroupRatio: 1.0,
}
result, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{})
if err != nil {
t.Fatal(err)
}
if result.ActualQuotaAfterGroup != 0 {
t.Errorf("after group = %d, want 0", result.ActualQuotaAfterGroup)
}
}
func TestComputeTieredQuota_RoundingEdge(t *testing.T) {
exprStr := `tier("default", p * 0.5)` // 3 * 0.5 = 1.5 -> round to 2
snap := &billingexpr.BillingSnapshot{
BillingMode: "tiered_expr",
ExprString: exprStr,
ExprHash: billingexpr.ExprHashString(exprStr),
GroupRatio: 1.0,
}
result, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{P: 3})
if err != nil {
t.Fatal(err)
}
// 3 * 0.5 = 1.5, round(1.5) = 2
if result.ActualQuotaAfterGroup != 2 {
t.Errorf("after group = %d, want 2 (round 1.5 up)", result.ActualQuotaAfterGroup)
}
}
func TestComputeTieredQuota_RoundingEdgeDown(t *testing.T) {
exprStr := `tier("default", p * 0.4)` // 3 * 0.4 = 1.2 -> round to 1
snap := &billingexpr.BillingSnapshot{
BillingMode: "tiered_expr",
ExprString: exprStr,
ExprHash: billingexpr.ExprHashString(exprStr),
GroupRatio: 1.0,
}
result, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{P: 3})
if err != nil {
t.Fatal(err)
}
// 3 * 0.4 = 1.2, round(1.2) = 1
if result.ActualQuotaAfterGroup != 1 {
t.Errorf("after group = %d, want 1 (round 1.2 down)", result.ActualQuotaAfterGroup)
}
}
func TestComputeTieredQuotaWithRequest_ProbeAffectsQuota(t *testing.T) {
exprStr := `param("fast") == true ? tier("fast", p * 4) : tier("normal", p * 2)`
snap := &billingexpr.BillingSnapshot{
BillingMode: "tiered_expr",
ExprString: exprStr,
ExprHash: billingexpr.ExprHashString(exprStr),
GroupRatio: 1.0,
EstimatedTier: "normal",
}
// Without request: normal tier
r1, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{P: 1000})
if err != nil {
t.Fatal(err)
}
if r1.ActualQuotaAfterGroup != 2000 {
t.Errorf("normal = %d, want 2000", r1.ActualQuotaAfterGroup)
}
// With request: fast tier
r2, err := billingexpr.ComputeTieredQuotaWithRequest(snap, billingexpr.TokenParams{P: 1000}, billingexpr.RequestInput{
Body: []byte(`{"fast":true}`),
})
if err != nil {
t.Fatal(err)
}
if r2.ActualQuotaAfterGroup != 4000 {
t.Errorf("fast = %d, want 4000", r2.ActualQuotaAfterGroup)
}
if !r2.CrossedTier {
t.Error("expected CrossedTier = true when probe changes tier")
}
}
func TestComputeTieredQuota_BoundaryTierCrossing(t *testing.T) {
exprStr := `p <= 100000 ? tier("small", p * 1) : tier("large", p * 2)`
snap := &billingexpr.BillingSnapshot{
BillingMode: "tiered_expr",
ExprString: exprStr,
ExprHash: billingexpr.ExprHashString(exprStr),
GroupRatio: 1.0,
EstimatedTier: "small",
}
// At boundary
r1, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{P: 100000})
if err != nil {
t.Fatal(err)
}
if r1.MatchedTier != "small" {
t.Errorf("at boundary: tier = %s, want small", r1.MatchedTier)
}
if r1.ActualQuotaAfterGroup != 100000 {
t.Errorf("at boundary: quota = %d, want 100000", r1.ActualQuotaAfterGroup)
}
// Past boundary
r2, err := billingexpr.ComputeTieredQuota(snap, billingexpr.TokenParams{P: 100001})
if err != nil {
t.Fatal(err)
}
if r2.MatchedTier != "large" {
t.Errorf("past boundary: tier = %s, want large", r2.MatchedTier)
}
if r2.ActualQuotaAfterGroup != 200002 {
t.Errorf("past boundary: quota = %d, want 200002", r2.ActualQuotaAfterGroup)
}
if !r2.CrossedTier {
t.Error("expected CrossedTier = true")
}
}
// ---------------------------------------------------------------------------
// Time function tests
// ---------------------------------------------------------------------------
func TestTimeFunctions_ValidTimezone(t *testing.T) {
exprStr := `tier("default", p) * (hour("UTC") >= 0 ? 1 : 1)`
cost, _, err := billingexpr.RunExpr(exprStr, billingexpr.TokenParams{P: 100})
if err != nil {
t.Fatal(err)
}
if cost != 100 {
t.Errorf("cost = %f, want 100", cost)
}
}
func TestTimeFunctions_AllFunctionsCompile(t *testing.T) {
exprStr := `tier("default", p) * (hour("Asia/Shanghai") >= 0 ? 1 : 1) * (minute("UTC") >= 0 ? 1 : 1) * (weekday("UTC") >= 0 ? 1 : 1) * (month("UTC") >= 1 ? 1 : 1) * (day("UTC") >= 1 ? 1 : 1)`
cost, _, err := billingexpr.RunExpr(exprStr, billingexpr.TokenParams{P: 500})
if err != nil {
t.Fatal(err)
}
if cost != 500 {
t.Errorf("cost = %f, want 500", cost)
}
}
func TestTimeFunctions_InvalidTimezone(t *testing.T) {
exprStr := `tier("default", p) * (hour("Invalid/Zone") >= 0 ? 1 : 2)`
cost, _, err := billingexpr.RunExpr(exprStr, billingexpr.TokenParams{P: 100})
if err != nil {
t.Fatal(err)
}
// Invalid timezone falls back to UTC; hour is 0-23, so condition is always true
if cost != 100 {
t.Errorf("cost = %f, want 100 (fallback to UTC)", cost)
}
}
func TestTimeFunctions_EmptyTimezone(t *testing.T) {
exprStr := `tier("default", p) * (hour("") >= 0 ? 1 : 2)`
cost, _, err := billingexpr.RunExpr(exprStr, billingexpr.TokenParams{P: 100})
if err != nil {
t.Fatal(err)
}
if cost != 100 {
t.Errorf("cost = %f, want 100 (empty tz -> UTC)", cost)
}
}
func TestTimeFunctions_NightDiscountPattern(t *testing.T) {
exprStr := `tier("default", p * 2 + c * 10) * (hour("UTC") >= 21 || hour("UTC") < 6 ? 0.5 : 1)`
cost, _, err := billingexpr.RunExpr(exprStr, billingexpr.TokenParams{P: 1000, C: 500})
if err != nil {
t.Fatal(err)
}
// Base = 1000*2 + 500*10 = 7000; multiplier is either 0.5 or 1 depending on current UTC hour
if cost != 7000 && cost != 3500 {
t.Errorf("cost = %f, want 7000 or 3500", cost)
}
}
func TestTimeFunctions_WeekdayRange(t *testing.T) {
exprStr := `tier("default", p) * (weekday("UTC") >= 0 && weekday("UTC") <= 6 ? 1 : 999)`
cost, _, err := billingexpr.RunExpr(exprStr, billingexpr.TokenParams{P: 100})
if err != nil {
t.Fatal(err)
}
// weekday is always 0-6, so multiplier is always 1
if cost != 100 {
t.Errorf("cost = %f, want 100", cost)
}
}
func TestTimeFunctions_MonthDayPattern(t *testing.T) {
exprStr := `tier("default", p) * (month("Asia/Shanghai") == 1 && day("Asia/Shanghai") == 1 ? 0.5 : 1)`
cost, _, err := billingexpr.RunExpr(exprStr, billingexpr.TokenParams{P: 1000})
if err != nil {
t.Fatal(err)
}
// Either 1000 (not Jan 1) or 500 (Jan 1) — both are valid
if cost != 1000 && cost != 500 {
t.Errorf("cost = %f, want 1000 or 500", cost)
}
}
+91
View File
@@ -0,0 +1,91 @@
package billingexpr
import (
"fmt"
"math"
"sync"
"github.com/expr-lang/expr"
"github.com/expr-lang/expr/vm"
)
const maxCacheSize = 256
var (
cacheMu sync.RWMutex
cache = make(map[string]*vm.Program, 64)
)
// compileEnvPrototype is the type-checking prototype used at compile time.
// It declares the shape of the environment that RunExpr will provide.
// The tier() function is a no-op placeholder here; the real one with
// side-channel tracing is injected at runtime.
var compileEnvPrototype = map[string]interface{}{
"p": float64(0),
"c": float64(0),
"cr": float64(0),
"cc": float64(0),
"cc1h": float64(0),
"prompt_tokens": float64(0),
"completion_tokens": float64(0),
"cache_read_tokens": float64(0),
"cache_create_tokens": float64(0),
"cache_create_1h_tokens": float64(0),
"tier": func(string, float64) float64 { return 0 },
"header": func(string) string { return "" },
"param": func(string) interface{} { return nil },
"has": func(interface{}, string) bool { return false },
"hour": func(string) int { return 0 },
"minute": func(string) int { return 0 },
"weekday": func(string) int { return 0 },
"month": func(string) int { return 0 },
"day": func(string) int { return 0 },
"max": math.Max,
"min": math.Min,
"abs": math.Abs,
"ceil": math.Ceil,
"floor": math.Floor,
}
// CompileFromCache compiles an expression string, using a cached program when
// available. The cache is keyed by the SHA-256 hex digest of the expression.
func CompileFromCache(exprStr string) (*vm.Program, error) {
return compileFromCacheByHash(exprStr, ExprHashString(exprStr))
}
// CompileFromCacheByHash is like CompileFromCache but accepts a pre-computed
// hash, useful when the caller already has the BillingSnapshot.ExprHash.
func CompileFromCacheByHash(exprStr, hash string) (*vm.Program, error) {
return compileFromCacheByHash(exprStr, hash)
}
func compileFromCacheByHash(exprStr, hash string) (*vm.Program, error) {
cacheMu.RLock()
if prog, ok := cache[hash]; ok {
cacheMu.RUnlock()
return prog, nil
}
cacheMu.RUnlock()
prog, err := expr.Compile(exprStr, expr.Env(compileEnvPrototype), expr.AsFloat64())
if err != nil {
return nil, fmt.Errorf("expr compile error: %w", err)
}
cacheMu.Lock()
if len(cache) >= maxCacheSize {
cache = make(map[string]*vm.Program, 64)
}
cache[hash] = prog
cacheMu.Unlock()
return prog, nil
}
// InvalidateCache clears the compiled-expression cache.
// Called when billing rules are updated.
func InvalidateCache() {
cacheMu.Lock()
cache = make(map[string]*vm.Program, 64)
cacheMu.Unlock()
}
+10
View File
@@ -0,0 +1,10 @@
package billingexpr
import "math"
// QuotaRound converts a float64 quota value to int using half-away-from-zero
// rounding. Every tiered billing path (pre-consume, settlement, breakdown
// validation, log fields) MUST use this function to avoid +-1 discrepancies.
func QuotaRound(f float64) int {
return int(math.Round(f))
}
+139
View File
@@ -0,0 +1,139 @@
package billingexpr
import (
"fmt"
"math"
"strings"
"time"
"github.com/expr-lang/expr"
"github.com/expr-lang/expr/vm"
"github.com/tidwall/gjson"
)
// RunExpr compiles (with cache) and executes an expression string.
// The environment exposes:
// - p, c — prompt / completion tokens
// - cr, cc, cc1h — cache read / creation / creation-1h tokens
// - tier(name, value) — trace callback that records which tier matched
// - max, min, abs, ceil, floor — standard math helpers
//
// Returns the resulting float64 quota (before group ratio) and a TraceResult
// with side-channel info captured by tier() during execution.
func RunExpr(exprStr string, params TokenParams) (float64, TraceResult, error) {
return RunExprWithRequest(exprStr, params, RequestInput{})
}
func RunExprWithRequest(exprStr string, params TokenParams, request RequestInput) (float64, TraceResult, error) {
prog, err := CompileFromCache(exprStr)
if err != nil {
return 0, TraceResult{}, err
}
return runProgram(prog, params, request)
}
// RunExprByHash is like RunExpr but accepts a pre-computed hash for the cache
// lookup, avoiding a redundant SHA-256 computation when the caller already
// holds BillingSnapshot.ExprHash.
func RunExprByHash(exprStr, hash string, params TokenParams) (float64, TraceResult, error) {
return RunExprByHashWithRequest(exprStr, hash, params, RequestInput{})
}
func RunExprByHashWithRequest(exprStr, hash string, params TokenParams, request RequestInput) (float64, TraceResult, error) {
prog, err := CompileFromCacheByHash(exprStr, hash)
if err != nil {
return 0, TraceResult{}, err
}
return runProgram(prog, params, request)
}
func runProgram(prog *vm.Program, params TokenParams, request RequestInput) (float64, TraceResult, error) {
trace := TraceResult{}
headers := normalizeHeaders(request.Headers)
env := map[string]interface{}{
"p": params.P,
"c": params.C,
"cr": params.CR,
"cc": params.CC,
"cc1h": params.CC1h,
"prompt_tokens": params.P,
"completion_tokens": params.C,
"cache_read_tokens": params.CR,
"cache_create_tokens": params.CC,
"cache_create_1h_tokens": params.CC1h,
"tier": func(name string, value float64) float64 {
trace.MatchedTier = name
trace.Cost = value
return value
},
"header": func(key string) string {
return headers[strings.ToLower(strings.TrimSpace(key))]
},
"param": func(path string) interface{} {
path = strings.TrimSpace(path)
if path == "" || len(request.Body) == 0 {
return nil
}
result := gjson.GetBytes(request.Body, path)
if !result.Exists() {
return nil
}
return result.Value()
},
"has": func(source interface{}, substr string) bool {
if source == nil || substr == "" {
return false
}
return strings.Contains(fmt.Sprint(source), substr)
},
"hour": func(tz string) int { return timeInZone(tz).Hour() },
"minute": func(tz string) int { return timeInZone(tz).Minute() },
"weekday": func(tz string) int { return int(timeInZone(tz).Weekday()) },
"month": func(tz string) int { return int(timeInZone(tz).Month()) },
"day": func(tz string) int { return timeInZone(tz).Day() },
"max": math.Max,
"min": math.Min,
"abs": math.Abs,
"ceil": math.Ceil,
"floor": math.Floor,
}
out, err := expr.Run(prog, env)
if err != nil {
return 0, trace, fmt.Errorf("expr run error: %w", err)
}
f, ok := out.(float64)
if !ok {
return 0, trace, fmt.Errorf("expr result is %T, want float64", out)
}
return f, trace, nil
}
func timeInZone(tz string) time.Time {
tz = strings.TrimSpace(tz)
if tz == "" {
return time.Now().UTC()
}
loc, err := time.LoadLocation(tz)
if err != nil {
return time.Now().UTC()
}
return time.Now().In(loc)
}
func normalizeHeaders(headers map[string]string) map[string]string {
if len(headers) == 0 {
return map[string]string{}
}
normalized := make(map[string]string, len(headers))
for key, value := range headers {
k := strings.ToLower(strings.TrimSpace(key))
v := strings.TrimSpace(value)
if k == "" || v == "" {
continue
}
normalized[k] = v
}
return normalized
}
+24
View File
@@ -0,0 +1,24 @@
package billingexpr
// ComputeTieredQuota runs the Expr from a frozen BillingSnapshot against
// actual token counts and returns the settlement result.
func ComputeTieredQuota(snap *BillingSnapshot, params TokenParams) (TieredResult, error) {
return ComputeTieredQuotaWithRequest(snap, params, RequestInput{})
}
func ComputeTieredQuotaWithRequest(snap *BillingSnapshot, params TokenParams, request RequestInput) (TieredResult, error) {
cost, trace, err := RunExprByHashWithRequest(snap.ExprString, snap.ExprHash, params, request)
if err != nil {
return TieredResult{}, err
}
afterGroup := QuotaRound(cost * snap.GroupRatio)
crossed := trace.MatchedTier != snap.EstimatedTier
return TieredResult{
ActualQuotaBeforeGroup: cost,
ActualQuotaAfterGroup: afterGroup,
MatchedTier: trace.MatchedTier,
CrossedTier: crossed,
}, nil
}
+59
View File
@@ -0,0 +1,59 @@
package billingexpr
import (
"crypto/sha256"
"fmt"
)
type RequestInput struct {
Headers map[string]string
Body []byte
}
// TokenParams holds all token dimensions passed into an Expr evaluation.
// Fields beyond P and C are optional — when absent they default to 0,
// which means cache-unaware expressions keep working unchanged.
type TokenParams struct {
P float64 // prompt tokens
C float64 // completion tokens
CR float64 // cache read (hit) tokens
CC float64 // cache creation tokens (5-min TTL for Claude, generic for others)
CC1h float64 // cache creation tokens — 1-hour TTL (Claude only)
}
// TraceResult holds side-channel info captured by the tier() function
// during Expr execution. This replaces the old Breakdown mechanism —
// the Expr itself is the single source of truth for billing logic.
type TraceResult struct {
MatchedTier string `json:"matched_tier"`
Cost float64 `json:"cost"`
}
// BillingSnapshot captures the billing rule state frozen at pre-consume time.
// It is fully serializable and contains no compiled program pointers.
type BillingSnapshot struct {
BillingMode string `json:"billing_mode"`
ModelName string `json:"model_name"`
ExprString string `json:"expr_string"`
ExprHash string `json:"expr_hash"`
GroupRatio float64 `json:"group_ratio"`
EstimatedPromptTokens int `json:"estimated_prompt_tokens"`
EstimatedCompletionTokens int `json:"estimated_completion_tokens"`
EstimatedQuotaBeforeGroup float64 `json:"estimated_quota_before_group"`
EstimatedQuotaAfterGroup int `json:"estimated_quota_after_group"`
EstimatedTier string `json:"estimated_tier"`
}
// TieredResult holds everything needed after running tiered settlement.
type TieredResult struct {
ActualQuotaBeforeGroup float64 `json:"actual_quota_before_group"`
ActualQuotaAfterGroup int `json:"actual_quota_after_group"`
MatchedTier string `json:"matched_tier"`
CrossedTier bool `json:"crossed_tier"`
}
// ExprHashString returns the SHA-256 hex digest of an expression string.
func ExprHashString(expr string) string {
h := sha256.Sum256([]byte(expr))
return fmt.Sprintf("%x", h)
}
+1 -1
View File
@@ -46,7 +46,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
resp, err := adaptor.DoRequest(c, info, ioReader)
if err != nil {
return types.NewError(err, types.ErrorCodeDoRequestFailed)
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
}
statusCodeMappingStr := c.GetString("status_code_mapping")
+4 -1
View File
@@ -2,6 +2,7 @@ package relay
import (
"bytes"
"io"
"net/http"
"strings"
@@ -124,8 +125,10 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
var requestBody io.Reader = bytes.NewBuffer(jsonData)
var httpResp *http.Response
resp, err := adaptor.DoRequest(c, info, bytes.NewBuffer(jsonData))
resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
}
+3
View File
@@ -18,4 +18,7 @@ type BillingSettler interface {
// GetPreConsumedQuota 返回实际预扣的额度值(信任用户可能为 0)。
GetPreConsumedQuota() int
// Reserve 将预扣额度补到目标值;若目标值不高于当前预扣额度则不做任何事。
Reserve(targetQuota int) error
}
+6
View File
@@ -10,6 +10,7 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/pkg/billingexpr"
relayconstant "github.com/QuantumNous/new-api/relay/constant"
"github.com/QuantumNous/new-api/setting/model_setting"
"github.com/QuantumNous/new-api/types"
@@ -152,6 +153,11 @@ type RelayInfo struct {
PriceData types.PriceData
// TieredBillingSnapshot is a frozen snapshot of tiered billing rules
// captured at pre-consume time. Non-nil only when billing mode is "tiered_expr".
TieredBillingSnapshot *billingexpr.BillingSnapshot
BillingRequestInput *billingexpr.RequestInput
Request dto.Request
// RequestConversionChain records request format conversions in order, e.g.
+61
View File
@@ -13,6 +13,7 @@ import (
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/pkg/billingexpr"
relaycommon "github.com/QuantumNous/new-api/relay/common"
relayconstant "github.com/QuantumNous/new-api/relay/constant"
"github.com/QuantumNous/new-api/relay/helper"
@@ -236,6 +237,18 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
service.ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat())
}
// Tiered billing early return
if ok, tieredQuota, tieredResult := service.TryTieredSettle(relayInfo, billingexpr.TokenParams{
P: float64(usage.PromptTokens),
C: float64(usage.CompletionTokens),
CR: float64(usage.PromptTokensDetails.CachedTokens),
CC: float64(usage.PromptTokensDetails.CachedCreationTokens - usage.ClaudeCacheCreation1hTokens),
CC1h: float64(usage.ClaudeCacheCreation1hTokens),
}); ok {
postConsumeQuotaTiered(ctx, relayInfo, usage, tieredQuota, tieredResult, extraContent...)
return
}
adminRejectReason := common.GetContextKeyString(ctx, constant.ContextKeyAdminRejectReason)
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
@@ -506,3 +519,51 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
Other: other,
})
}
func postConsumeQuotaTiered(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, quota int, tieredResult *service.TieredResultWrapper, extraContent ...string) {
_ = tieredResult // will be used for log enrichment
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
modelName := relayInfo.OriginModelName
tokenName := ctx.GetString("token_name")
groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
totalTokens := usage.PromptTokens + usage.CompletionTokens
if totalTokens == 0 {
quota = 0
extraContent = append(extraContent, "上游没有返回计费信息,无法扣费(可能是上游超时)")
logger.LogError(ctx, fmt.Sprintf("tiered billing: total tokens is 0, userId %d, channelId %d, tokenId %d, model %s, pre-consumed %d",
relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
} else {
if groupRatio != 0 && quota == 0 {
quota = 1
}
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
if err := service.SettleBilling(ctx, relayInfo, quota); err != nil {
logger.LogError(ctx, "error settling tiered billing: "+err.Error())
}
logModel := modelName
logContent := strings.Join(extraContent, ", ")
other := service.GenerateTieredOtherInfo(ctx, relayInfo, tieredResult)
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: relayInfo.ChannelId,
PromptTokens: usage.PromptTokens,
CompletionTokens: usage.CompletionTokens,
ModelName: logModel,
TokenName: tokenName,
Quota: quota,
Content: logContent,
TokenId: relayInfo.TokenId,
UseTimeSeconds: int(useTimeSeconds),
IsStream: relayInfo.IsStream,
Group: relayInfo.UsingGroup,
Other: other,
})
}
+2 -1
View File
@@ -3,6 +3,7 @@ package relay
import (
"bytes"
"fmt"
"io"
"net/http"
"github.com/QuantumNous/new-api/common"
@@ -58,7 +59,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
}
logger.LogDebug(c, fmt.Sprintf("converted embedding request body: %s", string(jsonData)))
requestBody := bytes.NewBuffer(jsonData)
var requestBody io.Reader = bytes.NewBuffer(jsonData)
statusCodeMappingStr := c.GetString("status_code_mapping")
resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
+54
View File
@@ -0,0 +1,54 @@
package helper
import (
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/pkg/billingexpr"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/gin-gonic/gin"
)
func ResolveIncomingBillingExprRequestInput(c *gin.Context, info *relaycommon.RelayInfo) (billingexpr.RequestInput, error) {
input := billingexpr.RequestInput{}
if info != nil {
input.Headers = cloneStringMap(info.RequestHeaders)
}
bodyBytes, err := readIncomingBillingExprBody(c)
if err != nil {
return billingexpr.RequestInput{}, err
}
input.Body = bodyBytes
return input, nil
}
func readIncomingBillingExprBody(c *gin.Context) ([]byte, error) {
if c == nil || c.Request == nil || !isJSONContentType(c.Request.Header.Get("Content-Type")) {
return nil, nil
}
storage, err := common.GetBodyStorage(c)
if err != nil {
return nil, err
}
return storage.Bytes()
}
func isJSONContentType(contentType string) bool {
contentType = strings.ToLower(strings.TrimSpace(contentType))
return strings.HasPrefix(contentType, "application/json")
}
func cloneStringMap(src map[string]string) map[string]string {
if len(src) == 0 {
return map[string]string{}
}
dst := make(map[string]string, len(src))
for key, value := range src {
if strings.TrimSpace(key) == "" {
continue
}
dst[key] = value
}
return dst
}
+35
View File
@@ -0,0 +1,35 @@
package helper
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/QuantumNous/new-api/common"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestResolveIncomingBillingExprRequestInput(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
ctx.Request.Header.Set("Content-Type", "application/json")
body := []byte(`{"service_tier":"fast"}`)
ctx.Request.Body = io.NopCloser(bytes.NewReader(body))
ctx.Set(common.KeyRequestBody, body)
info := &relaycommon.RelayInfo{
RequestHeaders: map[string]string{"Content-Type": "application/json"},
}
input, err := ResolveIncomingBillingExprRequestInput(ctx, info)
require.NoError(t, err)
require.Equal(t, body, input.Body)
require.Equal(t, "application/json", input.Headers["Content-Type"])
}
+75
View File
@@ -5,7 +5,9 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/pkg/billingexpr"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/setting/billing_setting"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/QuantumNous/new-api/types"
@@ -50,6 +52,11 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
groupRatioInfo := HandleGroupRatio(c, info)
// Check if this model uses tiered_expr billing
if billing_setting.GetBillingMode(info.OriginModelName) == billing_setting.BillingModeTieredExpr {
return modelPriceHelperTiered(c, info, promptTokens, meta, groupRatioInfo)
}
var preConsumedQuota int
var modelRatio float64
var completionRatio float64
@@ -195,5 +202,73 @@ func ContainPriceOrRatio(modelName string) bool {
if ok {
return true
}
if billing_setting.GetBillingMode(modelName) == billing_setting.BillingModeTieredExpr {
_, ok = billing_setting.GetBillingExpr(modelName)
return ok
}
return false
}
func modelPriceHelperTiered(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, meta *types.TokenCountMeta, groupRatioInfo types.GroupRatioInfo) (types.PriceData, error) {
exprStr, ok := billing_setting.GetBillingExpr(info.OriginModelName)
if !ok {
return types.PriceData{}, fmt.Errorf("model %s is configured as tiered_expr but has no billing expression", info.OriginModelName)
}
estimatedCompletionTokens := 0
if meta.MaxTokens != 0 {
estimatedCompletionTokens = meta.MaxTokens
}
requestInput, err := ResolveIncomingBillingExprRequestInput(c, info)
if err != nil {
return types.PriceData{}, err
}
rawQuota, trace, err := billingexpr.RunExprWithRequest(exprStr, billingexpr.TokenParams{
P: float64(promptTokens),
C: float64(estimatedCompletionTokens),
}, requestInput)
if err != nil {
return types.PriceData{}, fmt.Errorf("model %s tiered expr run failed: %w", info.OriginModelName, err)
}
preConsumedQuota := billingexpr.QuotaRound(rawQuota * groupRatioInfo.GroupRatio)
freeModel := false
if !operation_setting.GetQuotaSetting().EnableFreeModelPreConsume {
if groupRatioInfo.GroupRatio == 0 || rawQuota == 0 {
preConsumedQuota = 0
freeModel = true
}
}
exprHash := billingexpr.ExprHashString(exprStr)
snapshot := &billingexpr.BillingSnapshot{
BillingMode: billing_setting.BillingModeTieredExpr,
ModelName: info.OriginModelName,
ExprString: exprStr,
ExprHash: exprHash,
GroupRatio: groupRatioInfo.GroupRatio,
EstimatedPromptTokens: promptTokens,
EstimatedCompletionTokens: estimatedCompletionTokens,
EstimatedQuotaBeforeGroup: rawQuota,
EstimatedQuotaAfterGroup: preConsumedQuota,
EstimatedTier: trace.MatchedTier,
}
info.TieredBillingSnapshot = snapshot
info.BillingRequestInput = &requestInput
priceData := types.PriceData{
FreeModel: freeModel,
GroupRatioInfo: groupRatioInfo,
QuotaToPreConsume: preConsumedQuota,
}
if common.DebugEnabled {
println(fmt.Sprintf("model_price_helper_tiered result: model=%s preConsume=%d rawQuota=%.2f groupRatio=%.2f tier=%s", info.OriginModelName, preConsumedQuota, rawQuota, groupRatioInfo.GroupRatio, trace.MatchedTier))
}
info.PriceData = priceData
return priceData, nil
}
+89 -2
View File
@@ -27,6 +27,8 @@ type BillingSession struct {
funding FundingSource
preConsumedQuota int // 实际预扣额度(信任用户可能为 0)
tokenConsumed int // 令牌额度实际扣减量
extraReserved int // 发送前补充预扣的额度(订阅退款时需要单独回滚)
trusted bool // 是否命中信任额度旁路
fundingSettled bool // funding.Settle 已成功,资金来源已提交
settled bool // Settle 全部完成(资金 + 令牌)
refunded bool // Refund 已调用
@@ -97,6 +99,8 @@ func (s *BillingSession) Refund(c *gin.Context) {
tokenKey := s.relayInfo.TokenKey
isPlayground := s.relayInfo.IsPlayground
tokenConsumed := s.tokenConsumed
extraReserved := s.extraReserved
subscriptionId := s.relayInfo.SubscriptionId
funding := s.funding
gopool.Go(func() {
@@ -104,6 +108,11 @@ func (s *BillingSession) Refund(c *gin.Context) {
if err := funding.Refund(); err != nil {
common.SysLog("error refunding billing source: " + err.Error())
}
if extraReserved > 0 && funding.Source() == BillingSourceSubscription && subscriptionId > 0 {
if err := model.PostConsumeUserSubscriptionDelta(subscriptionId, -int64(extraReserved)); err != nil {
common.SysLog("error refunding subscription extra reserved quota: " + err.Error())
}
}
// 2) 退还令牌额度
if tokenConsumed > 0 && !isPlayground {
if err := model.IncreaseTokenQuota(tokenId, tokenKey, tokenConsumed); err != nil {
@@ -140,6 +149,34 @@ func (s *BillingSession) GetPreConsumedQuota() int {
return s.preConsumedQuota
}
func (s *BillingSession) Reserve(targetQuota int) error {
s.mu.Lock()
defer s.mu.Unlock()
if s.settled || s.refunded || s.trusted || targetQuota <= s.preConsumedQuota {
return nil
}
delta := targetQuota - s.preConsumedQuota
if delta <= 0 {
return nil
}
if err := s.reserveFunding(delta); err != nil {
return err
}
if err := s.reserveToken(delta); err != nil {
s.rollbackFundingReserve(delta)
return err
}
s.preConsumedQuota += delta
s.tokenConsumed += delta
s.extraReserved += delta
s.syncRelayInfo()
return nil
}
// ---------------------------------------------------------------------------
// PreConsume — 统一预扣费入口(含信任额度旁路)
// ---------------------------------------------------------------------------
@@ -151,6 +188,7 @@ func (s *BillingSession) preConsume(c *gin.Context, quota int) *types.NewAPIErro
// ---- 信任额度旁路 ----
if s.shouldTrust(c) {
s.trusted = true
effectiveQuota = 0
logger.LogInfo(c, fmt.Sprintf("用户 %d 额度充足, 信任且不需要预扣费 (funding=%s)", s.relayInfo.UserId, s.funding.Source()))
} else if effectiveQuota > 0 {
@@ -191,6 +229,55 @@ func (s *BillingSession) preConsume(c *gin.Context, quota int) *types.NewAPIErro
return nil
}
func (s *BillingSession) reserveFunding(delta int) error {
switch funding := s.funding.(type) {
case *WalletFunding:
if err := model.DecreaseUserQuota(funding.userId, delta); err != nil {
return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
}
funding.consumed += delta
return nil
case *SubscriptionFunding:
if err := model.PostConsumeUserSubscriptionDelta(funding.subscriptionId, int64(delta)); err != nil {
return types.NewErrorWithStatusCode(
fmt.Errorf("订阅额度不足或未配置订阅: %s", err.Error()),
types.ErrorCodeInsufficientUserQuota,
http.StatusForbidden,
types.ErrOptionWithSkipRetry(),
types.ErrOptionWithNoRecordErrorLog(),
)
}
return nil
default:
return types.NewError(fmt.Errorf("unsupported funding source: %s", s.funding.Source()), types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
}
}
func (s *BillingSession) rollbackFundingReserve(delta int) {
switch funding := s.funding.(type) {
case *WalletFunding:
if err := model.IncreaseUserQuota(funding.userId, delta, false); err != nil {
common.SysLog("error rolling back wallet funding reserve: " + err.Error())
} else {
funding.consumed -= delta
}
case *SubscriptionFunding:
if err := model.PostConsumeUserSubscriptionDelta(funding.subscriptionId, -int64(delta)); err != nil {
common.SysLog("error rolling back subscription funding reserve: " + err.Error())
}
}
}
func (s *BillingSession) reserveToken(delta int) error {
if delta <= 0 || s.relayInfo.IsPlayground {
return nil
}
if err := PreConsumeTokenQuota(s.relayInfo, delta); err != nil {
return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
}
return nil
}
// shouldTrust 统一信任额度检查,适用于钱包和订阅。
func (s *BillingSession) shouldTrust(c *gin.Context) bool {
// 异步任务(ForcePreConsume=true)必须预扣全额,不允许信任旁路
@@ -235,10 +322,10 @@ func (s *BillingSession) syncRelayInfo() {
if sub, ok := s.funding.(*SubscriptionFunding); ok {
info.SubscriptionId = sub.subscriptionId
info.SubscriptionPreConsumed = sub.preConsumed
info.SubscriptionPreConsumed = sub.preConsumed + int64(s.extraReserved)
info.SubscriptionPostDelta = 0
info.SubscriptionAmountTotal = sub.AmountTotal
info.SubscriptionAmountUsedAfterPreConsume = sub.AmountUsedAfter
info.SubscriptionAmountUsedAfterPreConsume = sub.AmountUsedAfter + int64(s.extraReserved)
info.SubscriptionPlanId = sub.PlanId
info.SubscriptionPlanTitle = sub.PlanTitle
} else {
+40
View File
@@ -6,6 +6,7 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/pkg/billingexpr"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/types"
@@ -214,3 +215,42 @@ func GenerateMjOtherInfo(relayInfo *relaycommon.RelayInfo, priceData types.Price
appendRequestPath(nil, relayInfo, other)
return other
}
func GenerateTieredOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, result *billingexpr.TieredResult) map[string]interface{} {
other := make(map[string]interface{})
other["billing_mode"] = "tiered_expr"
snap := relayInfo.TieredBillingSnapshot
if snap != nil {
other["group_ratio"] = snap.GroupRatio
other["expr_hash"] = snap.ExprHash
other["estimated_prompt_tokens"] = snap.EstimatedPromptTokens
other["estimated_completion_tokens"] = snap.EstimatedCompletionTokens
other["estimated_quota_before_group"] = snap.EstimatedQuotaBeforeGroup
other["estimated_quota_after_group"] = snap.EstimatedQuotaAfterGroup
other["estimated_tier"] = snap.EstimatedTier
}
if result != nil {
other["actual_quota_before_group"] = result.ActualQuotaBeforeGroup
other["actual_quota_after_group"] = result.ActualQuotaAfterGroup
other["matched_tier"] = result.MatchedTier
other["crossed_tier"] = result.CrossedTier
}
other["frt"] = float64(relayInfo.FirstResponseTime.UnixMilli() - relayInfo.StartTime.UnixMilli())
if relayInfo.IsModelMapped {
other["is_model_mapped"] = true
other["upstream_model_name"] = relayInfo.UpstreamModelName
}
adminInfo := make(map[string]interface{})
adminInfo["use_channel"] = ctx.GetStringSlice("use_channel")
AppendChannelAffinityAdminInfo(ctx, adminInfo)
other["admin_info"] = adminInfo
appendRequestPath(ctx, relayInfo, other)
appendRequestConversionChain(relayInfo, other)
appendBillingInfo(relayInfo, other)
return other
}
+79
View File
@@ -13,6 +13,7 @@ import (
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/pkg/billingexpr"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/QuantumNous/new-api/setting/system_setting"
@@ -157,6 +158,15 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
usage *dto.RealtimeUsage, extraContent string) {
// Tiered billing early return
if ok, tieredQuota, tieredResult := TryTieredSettle(relayInfo, billingexpr.TokenParams{
P: float64(usage.InputTokens),
C: float64(usage.OutputTokens),
}); ok {
postConsumeQuotaTieredService(ctx, relayInfo, modelName, usage.InputTokens, usage.OutputTokens, usage.TotalTokens, tieredQuota, tieredResult, extraContent)
return
}
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
textInputTokens := usage.InputTokenDetails.TextTokens
textOutTokens := usage.OutputTokenDetails.TextTokens
@@ -240,6 +250,18 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat())
}
// Tiered billing early return
if ok, tieredQuota, tieredResult := TryTieredSettle(relayInfo, billingexpr.TokenParams{
P: float64(usage.PromptTokens),
C: float64(usage.CompletionTokens),
CR: float64(usage.PromptTokensDetails.CachedTokens),
CC: float64(usage.PromptTokensDetails.CachedCreationTokens - usage.ClaudeCacheCreation1hTokens),
CC1h: float64(usage.ClaudeCacheCreation1hTokens),
}); ok {
postConsumeQuotaTieredService(ctx, relayInfo, relayInfo.OriginModelName, usage.PromptTokens, usage.CompletionTokens, usage.PromptTokens+usage.CompletionTokens, tieredQuota, tieredResult, "")
return
}
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens
@@ -360,6 +382,16 @@ func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData types.PriceData)
func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) {
// Tiered billing early return
if ok, tieredQuota, tieredResult := TryTieredSettle(relayInfo, billingexpr.TokenParams{
P: float64(usage.PromptTokens),
C: float64(usage.CompletionTokens),
CR: float64(usage.PromptTokensDetails.CachedTokens),
}); ok {
postConsumeQuotaTieredService(ctx, relayInfo, relayInfo.OriginModelName, usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens, tieredQuota, tieredResult, extraContent)
return
}
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
textInputTokens := usage.PromptTokensDetails.TextTokens
textOutTokens := usage.CompletionTokenDetails.TextTokens
@@ -607,3 +639,50 @@ func checkAndSendSubscriptionQuotaNotify(relayInfo *relaycommon.RelayInfo) {
}
})
}
func postConsumeQuotaTieredService(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
promptTokens, completionTokens, totalTokens, quota int, tieredResult *TieredResultWrapper, extraContent string) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
tokenName := ctx.GetString("token_name")
groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
var logContent string
if totalTokens == 0 {
quota = 0
logContent = "上游没有返回计费信息(可能是上游超时)"
logger.LogError(ctx, fmt.Sprintf("tiered billing: total tokens is 0, userId %d, channelId %d, tokenId %d, model %s, pre-consumed %d",
relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
} else {
if groupRatio != 0 && quota == 0 {
quota = 1
}
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
if err := SettleBilling(ctx, relayInfo, quota); err != nil {
logger.LogError(ctx, "error settling tiered billing: "+err.Error())
}
if extraContent != "" {
logContent += extraContent
}
other := GenerateTieredOtherInfo(ctx, relayInfo, tieredResult)
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: relayInfo.ChannelId,
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
ModelName: modelName,
TokenName: tokenName,
Quota: quota,
Content: logContent,
TokenId: relayInfo.TokenId,
UseTimeSeconds: int(useTimeSeconds),
IsStream: relayInfo.IsStream,
Group: relayInfo.UsingGroup,
Other: other,
})
}
+36
View File
@@ -0,0 +1,36 @@
package service
import (
"github.com/QuantumNous/new-api/pkg/billingexpr"
relaycommon "github.com/QuantumNous/new-api/relay/common"
)
// TieredResultWrapper wraps billingexpr.TieredResult for use at the service layer.
type TieredResultWrapper = billingexpr.TieredResult
// TryTieredSettle checks if the request uses tiered_expr billing and, if so,
// computes the actual quota using the frozen BillingSnapshot. Returns:
// - ok=true, quota, result when tiered billing applies
// - ok=false, 0, nil when it doesn't (caller should fall through to existing logic)
func TryTieredSettle(relayInfo *relaycommon.RelayInfo, params billingexpr.TokenParams) (ok bool, quota int, result *billingexpr.TieredResult) {
snap := relayInfo.TieredBillingSnapshot
if snap == nil || snap.BillingMode != "tiered_expr" {
return false, 0, nil
}
requestInput := billingexpr.RequestInput{}
if relayInfo.BillingRequestInput != nil {
requestInput = *relayInfo.BillingRequestInput
}
tr, err := billingexpr.ComputeTieredQuotaWithRequest(snap, params, requestInput)
if err != nil {
quota = relayInfo.FinalPreConsumedQuota
if quota <= 0 {
quota = snap.EstimatedQuotaAfterGroup
}
return true, quota, nil
}
return true, tr.ActualQuotaAfterGroup, &tr
}
+401
View File
@@ -0,0 +1,401 @@
package service
import (
"testing"
"github.com/QuantumNous/new-api/pkg/billingexpr"
relaycommon "github.com/QuantumNous/new-api/relay/common"
)
// Claude Sonnet-style tiered expression: standard vs long-context
const sonnetTieredExpr = `p <= 200000 ? tier("standard", p * 1.5 + c * 7.5) : tier("long_context", p * 3 + c * 11.25)`
// Simple flat expression
const flatExpr = `tier("default", p * 2 + c * 10)`
// Expression with cache tokens
const cacheExpr = `tier("default", p * 2 + c * 10 + cr * 0.2 + cc * 2.5 + cc1h * 4)`
// Expression with request probes
const probeExpr = `param("service_tier") == "fast" ? tier("fast", p * 4 + c * 20) : tier("normal", p * 2 + c * 10)`
func makeSnapshot(expr string, groupRatio float64, estPrompt, estCompletion int) *billingexpr.BillingSnapshot {
return &billingexpr.BillingSnapshot{
BillingMode: "tiered_expr",
ExprString: expr,
ExprHash: billingexpr.ExprHashString(expr),
GroupRatio: groupRatio,
EstimatedPromptTokens: estPrompt,
EstimatedCompletionTokens: estCompletion,
}
}
func makeRelayInfo(expr string, groupRatio float64, estPrompt, estCompletion int) *relaycommon.RelayInfo {
snap := makeSnapshot(expr, groupRatio, estPrompt, estCompletion)
cost, trace, _ := billingexpr.RunExpr(expr, billingexpr.TokenParams{P: float64(estPrompt), C: float64(estCompletion)})
snap.EstimatedQuotaBeforeGroup = cost
snap.EstimatedQuotaAfterGroup = billingexpr.QuotaRound(cost * groupRatio)
snap.EstimatedTier = trace.MatchedTier
return &relaycommon.RelayInfo{
TieredBillingSnapshot: snap,
FinalPreConsumedQuota: snap.EstimatedQuotaAfterGroup,
}
}
// ---------------------------------------------------------------------------
// Existing tests (preserved)
// ---------------------------------------------------------------------------
func TestTryTieredSettleUsesFrozenRequestInput(t *testing.T) {
exprStr := `param("service_tier") == "fast" ? tier("fast", p * 2) : tier("normal", p)`
relayInfo := &relaycommon.RelayInfo{
TieredBillingSnapshot: &billingexpr.BillingSnapshot{
BillingMode: "tiered_expr",
ExprString: exprStr,
ExprHash: billingexpr.ExprHashString(exprStr),
GroupRatio: 1.0,
EstimatedPromptTokens: 100,
EstimatedCompletionTokens: 0,
EstimatedQuotaAfterGroup: 100,
},
BillingRequestInput: &billingexpr.RequestInput{
Body: []byte(`{"service_tier":"fast"}`),
},
}
ok, quota, result := TryTieredSettle(relayInfo, billingexpr.TokenParams{P: 100})
if !ok {
t.Fatal("expected tiered settle to apply")
}
if quota != 200 {
t.Fatalf("quota = %d, want 200", quota)
}
if result == nil || result.MatchedTier != "fast" {
t.Fatalf("matched tier = %v, want fast", result)
}
}
func TestTryTieredSettleFallsBackToFrozenPreConsumeOnExprError(t *testing.T) {
relayInfo := &relaycommon.RelayInfo{
FinalPreConsumedQuota: 321,
TieredBillingSnapshot: &billingexpr.BillingSnapshot{
BillingMode: "tiered_expr",
ExprString: `invalid +-+ expr`,
ExprHash: billingexpr.ExprHashString(`invalid +-+ expr`),
GroupRatio: 1.0,
EstimatedQuotaAfterGroup: 123,
},
}
ok, quota, result := TryTieredSettle(relayInfo, billingexpr.TokenParams{P: 100})
if !ok {
t.Fatal("expected tiered settle to apply")
}
if quota != 321 {
t.Fatalf("quota = %d, want 321", quota)
}
if result != nil {
t.Fatalf("result = %#v, want nil", result)
}
}
// ---------------------------------------------------------------------------
// Pre-consume vs Post-consume consistency
// ---------------------------------------------------------------------------
func TestTryTieredSettle_PreConsumeMatchesPostConsume(t *testing.T) {
info := makeRelayInfo(flatExpr, 1.0, 1000, 500)
params := billingexpr.TokenParams{P: 1000, C: 500}
ok, quota, _ := TryTieredSettle(info, params)
if !ok {
t.Fatal("expected tiered settle")
}
// p*2 + c*10 = 2000 + 5000 = 7000
if quota != 7000 {
t.Fatalf("quota = %d, want 7000", quota)
}
if quota != info.FinalPreConsumedQuota {
t.Fatalf("pre-consume %d != post-consume %d", info.FinalPreConsumedQuota, quota)
}
}
func TestTryTieredSettle_PostConsumeOverPreConsume(t *testing.T) {
info := makeRelayInfo(flatExpr, 1.0, 1000, 500)
preConsumed := info.FinalPreConsumedQuota // 7000
// Actual usage is higher than estimated
params := billingexpr.TokenParams{P: 2000, C: 1000}
ok, quota, _ := TryTieredSettle(info, params)
if !ok {
t.Fatal("expected tiered settle")
}
// p*2 + c*10 = 4000 + 10000 = 14000
if quota != 14000 {
t.Fatalf("quota = %d, want 14000", quota)
}
if quota <= preConsumed {
t.Fatalf("expected supplement: actual %d should > pre-consumed %d", quota, preConsumed)
}
}
func TestTryTieredSettle_PostConsumeUnderPreConsume(t *testing.T) {
info := makeRelayInfo(flatExpr, 1.0, 1000, 500)
preConsumed := info.FinalPreConsumedQuota // 7000
// Actual usage is lower than estimated
params := billingexpr.TokenParams{P: 100, C: 50}
ok, quota, _ := TryTieredSettle(info, params)
if !ok {
t.Fatal("expected tiered settle")
}
// p*2 + c*10 = 200 + 500 = 700
if quota != 700 {
t.Fatalf("quota = %d, want 700", quota)
}
if quota >= preConsumed {
t.Fatalf("expected refund: actual %d should < pre-consumed %d", quota, preConsumed)
}
}
// ---------------------------------------------------------------------------
// Tiered boundary conditions
// ---------------------------------------------------------------------------
func TestTryTieredSettle_ExactBoundary(t *testing.T) {
info := makeRelayInfo(sonnetTieredExpr, 1.0, 200000, 1000)
// p == 200000 => standard tier (p <= 200000)
ok, quota, result := TryTieredSettle(info, billingexpr.TokenParams{P: 200000, C: 1000})
if !ok {
t.Fatal("expected tiered settle")
}
// standard: p*1.5 + c*7.5 = 300000 + 7500 = 307500
if quota != 307500 {
t.Fatalf("quota = %d, want 307500", quota)
}
if result.MatchedTier != "standard" {
t.Fatalf("tier = %s, want standard", result.MatchedTier)
}
}
func TestTryTieredSettle_BoundaryPlusOne(t *testing.T) {
info := makeRelayInfo(sonnetTieredExpr, 1.0, 200000, 1000)
// p == 200001 => crosses to long_context tier
ok, quota, result := TryTieredSettle(info, billingexpr.TokenParams{P: 200001, C: 1000})
if !ok {
t.Fatal("expected tiered settle")
}
// long_context: p*3 + c*11.25 = 600003 + 11250 = 611253
if quota != 611253 {
t.Fatalf("quota = %d, want 611253", quota)
}
if result.MatchedTier != "long_context" {
t.Fatalf("tier = %s, want long_context", result.MatchedTier)
}
if !result.CrossedTier {
t.Fatal("expected CrossedTier = true")
}
}
func TestTryTieredSettle_ZeroTokens(t *testing.T) {
info := makeRelayInfo(flatExpr, 1.0, 0, 0)
ok, quota, result := TryTieredSettle(info, billingexpr.TokenParams{P: 0, C: 0})
if !ok {
t.Fatal("expected tiered settle")
}
if quota != 0 {
t.Fatalf("quota = %d, want 0", quota)
}
if result == nil {
t.Fatal("result should not be nil")
}
}
func TestTryTieredSettle_HugeTokens(t *testing.T) {
info := makeRelayInfo(flatExpr, 1.0, 10000000, 5000000)
ok, quota, _ := TryTieredSettle(info, billingexpr.TokenParams{P: 10000000, C: 5000000})
if !ok {
t.Fatal("expected tiered settle")
}
// p*2 + c*10 = 20000000 + 50000000 = 70000000
if quota != 70000000 {
t.Fatalf("quota = %d, want 70000000", quota)
}
}
func TestTryTieredSettle_CacheTokensAffectSettlement(t *testing.T) {
info := makeRelayInfo(cacheExpr, 1.0, 1000, 500)
// Without cache tokens
ok1, quota1, _ := TryTieredSettle(info, billingexpr.TokenParams{P: 1000, C: 500})
if !ok1 {
t.Fatal("expected tiered settle")
}
// p*2 + c*10 + cr*0.2 + cc*2.5 + cc1h*4 = 2000 + 5000 + 0 + 0 + 0 = 7000
// With cache tokens
ok2, quota2, _ := TryTieredSettle(info, billingexpr.TokenParams{P: 1000, C: 500, CR: 10000, CC: 5000, CC1h: 2000})
if !ok2 {
t.Fatal("expected tiered settle")
}
// 2000 + 5000 + 10000*0.2 + 5000*2.5 + 2000*4 = 2000 + 5000 + 2000 + 12500 + 8000 = 29500
if quota2 <= quota1 {
t.Fatalf("cache tokens should increase quota: without=%d, with=%d", quota1, quota2)
}
if quota1 != 7000 {
t.Fatalf("no-cache quota = %d, want 7000", quota1)
}
if quota2 != 29500 {
t.Fatalf("cache quota = %d, want 29500", quota2)
}
}
// ---------------------------------------------------------------------------
// Request probe tests
// ---------------------------------------------------------------------------
func TestTryTieredSettle_RequestProbeInfluencesBilling(t *testing.T) {
info := makeRelayInfo(probeExpr, 1.0, 1000, 500)
info.BillingRequestInput = &billingexpr.RequestInput{
Body: []byte(`{"service_tier":"fast"}`),
}
ok, quota, result := TryTieredSettle(info, billingexpr.TokenParams{P: 1000, C: 500})
if !ok {
t.Fatal("expected tiered settle")
}
// fast: p*4 + c*20 = 4000 + 10000 = 14000
if quota != 14000 {
t.Fatalf("quota = %d, want 14000", quota)
}
if result.MatchedTier != "fast" {
t.Fatalf("tier = %s, want fast", result.MatchedTier)
}
}
func TestTryTieredSettle_NoRequestInput_FallsBackToDefault(t *testing.T) {
info := makeRelayInfo(probeExpr, 1.0, 1000, 500)
// No BillingRequestInput set — param("service_tier") returns nil, not "fast"
ok, quota, result := TryTieredSettle(info, billingexpr.TokenParams{P: 1000, C: 500})
if !ok {
t.Fatal("expected tiered settle")
}
// normal: p*2 + c*10 = 2000 + 5000 = 7000
if quota != 7000 {
t.Fatalf("quota = %d, want 7000", quota)
}
if result.MatchedTier != "normal" {
t.Fatalf("tier = %s, want normal", result.MatchedTier)
}
}
// ---------------------------------------------------------------------------
// Group ratio tests
// ---------------------------------------------------------------------------
func TestTryTieredSettle_GroupRatioScaling(t *testing.T) {
info := makeRelayInfo(flatExpr, 1.5, 1000, 500)
ok, quota, _ := TryTieredSettle(info, billingexpr.TokenParams{P: 1000, C: 500})
if !ok {
t.Fatal("expected tiered settle")
}
// cost = 7000, after group = round(7000 * 1.5) = 10500
if quota != 10500 {
t.Fatalf("quota = %d, want 10500", quota)
}
}
func TestTryTieredSettle_GroupRatioZero(t *testing.T) {
info := makeRelayInfo(flatExpr, 0, 1000, 500)
ok, quota, _ := TryTieredSettle(info, billingexpr.TokenParams{P: 1000, C: 500})
if !ok {
t.Fatal("expected tiered settle")
}
if quota != 0 {
t.Fatalf("quota = %d, want 0 (group ratio = 0)", quota)
}
}
// ---------------------------------------------------------------------------
// Ratio mode (negative tests) — TryTieredSettle must return false
// ---------------------------------------------------------------------------
func TestTryTieredSettle_RatioMode_NilSnapshot(t *testing.T) {
info := &relaycommon.RelayInfo{
TieredBillingSnapshot: nil,
}
ok, _, _ := TryTieredSettle(info, billingexpr.TokenParams{P: 1000, C: 500})
if ok {
t.Fatal("expected TryTieredSettle to return false when snapshot is nil")
}
}
func TestTryTieredSettle_RatioMode_WrongBillingMode(t *testing.T) {
info := &relaycommon.RelayInfo{
TieredBillingSnapshot: &billingexpr.BillingSnapshot{
BillingMode: "ratio",
ExprString: flatExpr,
ExprHash: billingexpr.ExprHashString(flatExpr),
GroupRatio: 1.0,
},
}
ok, _, _ := TryTieredSettle(info, billingexpr.TokenParams{P: 1000, C: 500})
if ok {
t.Fatal("expected TryTieredSettle to return false for ratio billing mode")
}
}
func TestTryTieredSettle_RatioMode_EmptyBillingMode(t *testing.T) {
info := &relaycommon.RelayInfo{
TieredBillingSnapshot: &billingexpr.BillingSnapshot{
BillingMode: "",
ExprString: flatExpr,
ExprHash: billingexpr.ExprHashString(flatExpr),
GroupRatio: 1.0,
},
}
ok, _, _ := TryTieredSettle(info, billingexpr.TokenParams{P: 1000, C: 500})
if ok {
t.Fatal("expected TryTieredSettle to return false for empty billing mode")
}
}
// ---------------------------------------------------------------------------
// Fallback tests
// ---------------------------------------------------------------------------
func TestTryTieredSettle_ErrorFallbackToEstimatedQuotaAfterGroup(t *testing.T) {
info := &relaycommon.RelayInfo{
FinalPreConsumedQuota: 0,
TieredBillingSnapshot: &billingexpr.BillingSnapshot{
BillingMode: "tiered_expr",
ExprString: `invalid expr!!!`,
ExprHash: billingexpr.ExprHashString(`invalid expr!!!`),
GroupRatio: 1.0,
EstimatedQuotaAfterGroup: 999,
},
}
ok, quota, result := TryTieredSettle(info, billingexpr.TokenParams{P: 100})
if !ok {
t.Fatal("expected tiered settle to apply")
}
// FinalPreConsumedQuota is 0, should fall back to EstimatedQuotaAfterGroup
if quota != 999 {
t.Fatalf("quota = %d, want 999", quota)
}
if result != nil {
t.Fatal("result should be nil on error fallback")
}
}
+135
View File
@@ -0,0 +1,135 @@
package billing_setting
import (
"fmt"
"sync"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/pkg/billingexpr"
)
var (
mu sync.RWMutex
// model -> "ratio" | "tiered_expr"
billingModeMap = make(map[string]string)
// model -> expr string (authored by frontend, stored directly)
billingExprMap = make(map[string]string)
)
const (
BillingModeRatio = "ratio"
BillingModeTieredExpr = "tiered_expr"
)
// ---------------------------------------------------------------------------
// Read accessors (hot path, must be fast)
// ---------------------------------------------------------------------------
func GetBillingMode(model string) string {
mu.RLock()
defer mu.RUnlock()
if mode, ok := billingModeMap[model]; ok {
return mode
}
return BillingModeRatio
}
func GetBillingExpr(model string) (string, bool) {
mu.RLock()
defer mu.RUnlock()
expr, ok := billingExprMap[model]
return expr, ok
}
func UpdateBillingModeByJSONString(jsonStr string) error {
var m map[string]string
if err := common.Unmarshal([]byte(jsonStr), &m); err != nil {
return fmt.Errorf("parse ModelBillingMode: %w", err)
}
for k, v := range m {
if v != BillingModeRatio && v != BillingModeTieredExpr {
return fmt.Errorf("invalid billing mode %q for model %q", v, k)
}
}
mu.Lock()
billingModeMap = m
mu.Unlock()
return nil
}
func UpdateBillingExprByJSONString(jsonStr string) error {
var m map[string]string
if err := common.Unmarshal([]byte(jsonStr), &m); err != nil {
return fmt.Errorf("parse ModelBillingExpr: %w", err)
}
for model, exprStr := range m {
if _, err := billingexpr.CompileFromCache(exprStr); err != nil {
return fmt.Errorf("model %q: %w", model, err)
}
if err := smokeTestExpr(exprStr); err != nil {
return fmt.Errorf("model %q smoke test: %w", model, err)
}
}
mu.Lock()
billingExprMap = m
mu.Unlock()
billingexpr.InvalidateCache()
return nil
}
// ---------------------------------------------------------------------------
// JSON serializers (for OptionMap / API response)
// ---------------------------------------------------------------------------
func BillingMode2JSONString() string {
mu.RLock()
defer mu.RUnlock()
b, err := common.Marshal(billingModeMap)
if err != nil {
return "{}"
}
return string(b)
}
func BillingExpr2JSONString() string {
mu.RLock()
defer mu.RUnlock()
b, err := common.Marshal(billingExprMap)
if err != nil {
return "{}"
}
return string(b)
}
func smokeTestExpr(exprStr string) error {
vectors := []billingexpr.TokenParams{
{P: 0, C: 0},
{P: 1000, C: 1000},
{P: 100000, C: 100000},
{P: 1000000, C: 1000000},
}
requests := []billingexpr.RequestInput{
{},
{
Headers: map[string]string{
"anthropic-beta": "fast-mode-2026-02-01",
},
Body: []byte(`{"service_tier":"fast","stream_options":{"include_usage":true},"messages":[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21]}`),
},
}
for _, v := range vectors {
for _, request := range requests {
result, _, err := billingexpr.RunExprWithRequest(exprStr, v, request)
if err != nil {
return fmt.Errorf("vector {p=%g, c=%g}: run failed: %w", v.P, v.C, err)
}
if result < 0 {
return fmt.Errorf("vector {p=%g, c=%g}: result %f < 0", v.P, v.C, result)
}
}
}
return nil
}
+60
View File
@@ -0,0 +1,60 @@
package model_setting
import (
"net/http"
"testing"
)
func TestClaudeSettingsWriteHeadersMergesConfiguredValuesIntoSingleHeader(t *testing.T) {
settings := &ClaudeSettings{
HeadersSettings: map[string]map[string][]string{
"claude-3-7-sonnet-20250219-thinking": {
"anthropic-beta": {
"token-efficient-tools-2025-02-19",
},
},
},
}
headers := http.Header{}
headers.Set("anthropic-beta", "output-128k-2025-02-19")
settings.WriteHeaders("claude-3-7-sonnet-20250219-thinking", &headers)
got := headers.Values("anthropic-beta")
if len(got) != 1 {
t.Fatalf("expected a single merged header value, got %v", got)
}
expected := "output-128k-2025-02-19,token-efficient-tools-2025-02-19"
if got[0] != expected {
t.Fatalf("expected merged header %q, got %q", expected, got[0])
}
}
func TestClaudeSettingsWriteHeadersDeduplicatesAcrossCommaSeparatedAndRepeatedValues(t *testing.T) {
settings := &ClaudeSettings{
HeadersSettings: map[string]map[string][]string{
"claude-3-7-sonnet-20250219-thinking": {
"anthropic-beta": {
"token-efficient-tools-2025-02-19",
"computer-use-2025-01-24",
},
},
},
}
headers := http.Header{}
headers.Add("anthropic-beta", "output-128k-2025-02-19, token-efficient-tools-2025-02-19")
headers.Add("anthropic-beta", "token-efficient-tools-2025-02-19")
settings.WriteHeaders("claude-3-7-sonnet-20250219-thinking", &headers)
got := headers.Values("anthropic-beta")
if len(got) != 1 {
t.Fatalf("expected duplicate values to collapse into one header, got %v", got)
}
expected := "output-128k-2025-02-19,token-efficient-tools-2025-02-19,computer-use-2025-01-24"
if got[0] != expected {
t.Fatalf("expected deduplicated merged header %q, got %q", expected, got[0])
}
}
@@ -377,6 +377,43 @@ function renderCompactDetailSummary(summarySegments) {
);
}
function buildTieredBillingSegments(other, t) {
const segments = [
{ text: `${t('阶梯计费')}`, tone: 'primary' },
];
if (other.matched_tier) {
segments.push({
text: `${t('命中档位')}: ${other.matched_tier}`,
tone: 'secondary',
});
}
const groupRatio = other.group_ratio;
if (groupRatio !== undefined && groupRatio !== null) {
segments.push({
text: `${t('分组')} ${formatRatio(groupRatio)}x`,
tone: 'secondary',
});
}
if (other.crossed_tier) {
segments.push({
text: `${t('跨阶梯')}: ${t('是')}`,
tone: 'secondary',
});
}
if (other.actual_quota_after_group !== undefined) {
segments.push({
text: `${t('实际额度')}: ${other.actual_quota_after_group}`,
tone: 'secondary',
});
}
return { segments };
}
function getUsageLogDetailSummary(record, text, billingDisplayMode, t) {
const other = getLogOther(record.other);
@@ -414,6 +451,10 @@ function getUsageLogDetailSummary(record, text, billingDisplayMode, t) {
};
}
if (other?.billing_mode === 'tiered_expr') {
return buildTieredBillingSegments(other, t);
}
return {
segments: other?.claude
? renderModelPriceSimple(
+60
View File
@@ -559,6 +559,66 @@ export const useLogsData = () => {
value: other.reasoning_effort,
});
}
if (other?.billing_mode === 'tiered_expr') {
expandDataLocal.push({
key: t('计费方式'),
value: t('阶梯计费'),
});
if (other?.group_ratio !== undefined) {
const gr = other.group_ratio;
expandDataLocal.push({
key: t('分组倍率'),
value: typeof gr === 'number' ? gr.toFixed(4) : String(gr ?? '-'),
});
}
if (other?.rule_version !== undefined) {
expandDataLocal.push({
key: t('规则版本'),
value: String(other.rule_version),
});
}
if (other?.estimated_env) {
expandDataLocal.push({
key: t('预估环境'),
value: `prompt=${other.estimated_env.prompt_tokens ?? 0}, completion=${other.estimated_env.completion_tokens ?? 0}`,
});
}
if (other?.actual_env) {
expandDataLocal.push({
key: t('实际环境'),
value: `prompt=${other.actual_env.prompt_tokens ?? 0}, completion=${other.actual_env.completion_tokens ?? 0}`,
});
}
if (other?.estimated_quota_after_group !== undefined) {
expandDataLocal.push({
key: t('预估额度'),
value: String(other.estimated_quota_after_group),
});
}
if (other?.actual_quota_after_group !== undefined) {
expandDataLocal.push({
key: t('实际额度'),
value: String(other.actual_quota_after_group),
});
}
expandDataLocal.push({
key: t('跨阶梯'),
value: other?.crossed_tier ? t('是') : t('否'),
});
if (Array.isArray(other?.breakdown) && other.breakdown.length > 0) {
const breakdownText = other.breakdown.map((item, idx) =>
`[${idx}] ${item.token_type} | tokens=${item.tokens_in_tier} | cost=${item.unit_cost} | flat=${item.flat_fee} | sub=${item.subtotal}`
).join('\n');
expandDataLocal.push({
key: t('计费明细'),
value: (
<div style={{ whiteSpace: 'pre-line', fontFamily: 'monospace', fontSize: 12 }}>
{breakdownText}
</div>
),
});
}
}
}
if (logs[i].type === 6) {
if (other?.task_id) {
+113 -1
View File
@@ -3256,6 +3256,20 @@
"补全价格": "Completion Price",
"缓存读取价格": "Input Cache Read Price",
"缓存创建价格": "Input Cache Creation Price",
"缓存创建价格-5分钟": "Cache Creation Price (5-min)",
"缓存创建价格-1小时": "Cache Creation Price (1-hour)",
"缓存创建价格(5分钟)": "Cache Creation Price (5-min)",
"缓存创建价格(1小时)": "Cache Creation Price (1-hour)",
"分时缓存 (Claude)": "Timed Cache (Claude)",
"通用缓存": "Generic Cache",
"缓存读取": "Cache Read",
"缓存创建": "Cache Creation",
"缓存创建-5分钟": "Cache Creation (5-min)",
"缓存创建-1小时": "Cache Creation (1-hour)",
"缓存读取 Token (cr)": "Cache Read Tokens (cr)",
"缓存创建 Token (cc)": "Cache Creation Tokens (cc)",
"缓存创建-5分钟 (cc5)": "Cache Creation-5min (cc5)",
"缓存创建-1小时 (cc1h)": "Cache Creation-1hour (cc1h)",
"图片输入价格": "Image Input Price",
"音频输入价格": "Audio Input Price",
"音频输入价格:{{symbol}}{{price}} / 1M tokens": "Audio input price: {{symbol}}{{price}} / 1M tokens",
@@ -3309,6 +3323,104 @@
"输入价格:{{symbol}}{{price}} / 1M tokens": "Input Price: {{symbol}}{{price}} / 1M tokens",
"输出价格 {{symbol}}{{price}} / 1M tokens": "Output Price {{symbol}}{{price}} / 1M tokens",
"输出价格:{{symbol}}{{price}} / 1M tokens": "Output Price: {{symbol}}{{price}} / 1M tokens",
"输出价格:{{symbol}}{{total}} / 1M tokens": "Output Price: {{symbol}}{{total}} / 1M tokens"
"输出价格:{{symbol}}{{total}} / 1M tokens": "Output Price: {{symbol}}{{total}} / 1M tokens",
"阶梯计费": "Tiered Billing",
"输入 Tokens 阶梯": "Input Token Tiers",
"输出 Tokens 阶梯": "Output Token Tiers",
"固定阶梯": "Fixed Tier",
"累进阶梯": "Graduated Tier",
"上限": "Up To",
"单价": "Unit Cost",
"固定费": "Flat Fee",
"Expr 预览": "Expression Preview",
"Token 估算器": "Token Estimator",
"预计费用": "Estimated Cost",
"原始额度": "Raw Quota",
"添加阶梯": "Add Tier",
"无限": "Unlimited",
"输入 Token 定价": "Input Token Pricing",
"输出 Token 定价": "Output Token Pricing",
"统一定价": "Flat Rate",
"阶梯累进": "Graduated",
"根据总用量落在哪个档位,所有 Token 都按该档价格计费": "All tokens are charged at the rate of the tier your total usage falls into",
"用量分段计价,每一段各自按对应档位价格计费(类似电费阶梯)": "Usage is charged in segments — each segment at its own tier rate (like utility billing)",
"Token 用量范围": "Token Usage Range",
"所有 Token": "All Tokens",
"前 {{count}} 个": "First {{count}}",
"超过 {{count}} 个": "Over {{count}}",
"第 {{n}} 档": "Tier {{n}}",
"最高档": "Highest Tier",
"此档上限(Token 数)": "Tier Limit (Token Count)",
"每百万 Token 价格": "Price per 1M Tokens",
"进入此档额外收费": "Tier Entry Fee",
"可选,用量达到此档时加收的固定费用": "Optional fixed fee charged when usage reaches this tier",
"添加更多档位": "Add More Tiers",
"输入 Token 数": "Input Tokens",
"输出 Token 数": "Output Tokens",
"输入 Token 数量,查看按当前阶梯配置的预计费用。": "Enter token counts to see the estimated cost with the current tier configuration.",
"开发者": "Developer",
"阶梯计费详情": "Tiered Billing Details",
"预估环境": "Estimated Env",
"实际环境": "Actual Env",
"预估额度": "Estimated Quota",
"实际额度": "Actual Quota",
"跨阶梯": "Crossed Tier",
"是": "Yes",
"否": "No",
"计费明细": "Billing Breakdown",
"阶梯序号": "Tier #",
"Token 类型": "Token Type",
"阶梯内 Token 数": "Tokens in Tier",
"小计": "Subtotal",
"输入": "Input",
"输出": "Output",
"阶梯配置摘要": "Tier Config Summary",
"输入阶梯": "Input Tiers",
"档位名称": "Tier Name",
"用量范围": "Usage Range",
"输入 Token": "Input Token",
"输出 Token": "Output Token",
"阶梯判断依据": "Tier Criterion",
"根据哪个维度的 Token 数量决定落在哪一档": "Determines which tier to apply based on this dimension's token count",
"输入 Token 数 (p)": "Input Tokens (p)",
"输出 Token 数 (c)": "Output Tokens (c)",
"变量": "Variables",
"函数": "Functions",
"输入计费表达式...": "Enter billing expression...",
"表达式编辑": "Expression Editor",
"表达式错误": "Expression Error",
"命中档位": "Matched Tier",
"档": "tier(s)",
"输入 Token 数量,查看按当前配置的预计费用。": "Enter token counts to see the estimated cost.",
"输入 Token 数量,查看按当前配置的预计费用(不含分组倍率)。": "Enter token counts to see the estimated cost (before group ratio).",
"条件": "Condition",
"添加条件": "Add Condition",
"无条件(兜底档)": "No condition (fallback)",
"兜底档": "Fallback",
"预设模板": "Presets",
"每个档位可设置 0~2 个条件(对 p 和 c),最后一档为兜底档无需条件。": "Each tier can have 0-2 conditions (on p and c). The last tier is the fallback and needs no condition.",
"输出阶梯": "Output Tiers",
"阶": "tiers",
"规则版本": "Rule Version",
"时间条件": "Time condition",
"小时": "Hour",
"分钟": "Minute",
"星期": "Weekday",
"月份": "Month",
"日期": "Day",
"时区": "Timezone",
"跨夜范围": "Cross-midnight range",
"添加时间规则": "Add time rule",
"起": "From",
"止": "To",
"值": "Value",
"添加条件组": "Add condition group",
"添加条件": "Add condition",
"添加时间条件": "Add time condition",
"同时满足": "all must match",
"新年促销": "New Year promo",
"第 {{n}} 组": "Group {{n}}",
"0=周日 1=周一 2=周二 3=周三 4=周四 5=周五 6=周六": "0=Sun 1=Mon 2=Tue 3=Wed 4=Thu 5=Fri 6=Sat",
"1=一月 ... 12=十二月": "1=Jan ... 12=Dec"
}
}
+111 -1
View File
@@ -2858,6 +2858,20 @@
"补全价格": "补全价格",
"缓存读取价格": "缓存读取价格",
"缓存创建价格": "缓存创建价格",
"缓存创建价格-5分钟": "缓存创建价格-5分钟",
"缓存创建价格-1小时": "缓存创建价格-1小时",
"缓存创建价格(5分钟)": "缓存创建价格(5分钟)",
"缓存创建价格(1小时)": "缓存创建价格(1小时)",
"分时缓存 (Claude)": "分时缓存 (Claude)",
"通用缓存": "通用缓存",
"缓存读取": "缓存读取",
"缓存创建": "缓存创建",
"缓存创建-5分钟": "缓存创建-5分钟",
"缓存创建-1小时": "缓存创建-1小时",
"缓存读取 Token (cr)": "缓存读取 Token (cr)",
"缓存创建 Token (cc)": "缓存创建 Token (cc)",
"缓存创建-5分钟 (cc5)": "缓存创建-5分钟 (cc5)",
"缓存创建-1小时 (cc1h)": "缓存创建-1小时 (cc1h)",
"图片输入价格": "图片输入价格",
"音频输入价格": "音频输入价格",
"音频补全价格": "音频补全价格",
@@ -2938,6 +2952,102 @@
"输入价格:{{symbol}}{{price}} / 1M tokens": "输入价格:{{symbol}}{{price}} / 1M tokens",
"输出价格 {{symbol}}{{price}} / 1M tokens": "输出价格 {{symbol}}{{price}} / 1M tokens",
"输出价格:{{symbol}}{{price}} / 1M tokens": "输出价格:{{symbol}}{{price}} / 1M tokens",
"输出价格:{{symbol}}{{total}} / 1M tokens": "输出价格:{{symbol}}{{total}} / 1M tokens"
"输出价格:{{symbol}}{{total}} / 1M tokens": "输出价格:{{symbol}}{{total}} / 1M tokens",
"阶梯计费": "阶梯计费",
"输入 Tokens 阶梯": "输入 Tokens 阶梯",
"输出 Tokens 阶梯": "输出 Tokens 阶梯",
"固定阶梯": "固定阶梯",
"累进阶梯": "累进阶梯",
"上限": "上限",
"单价": "单价",
"固定费": "固定费",
"Expr 预览": "Expr 预览",
"Token 估算器": "Token 估算器",
"预计费用": "预计费用",
"添加阶梯": "添加阶梯",
"无限": "无限",
"输入 Token 定价": "输入 Token 定价",
"输出 Token 定价": "输出 Token 定价",
"统一定价": "统一定价",
"阶梯累进": "阶梯累进",
"根据总用量落在哪个档位,所有 Token 都按该档价格计费": "根据总用量落在哪个档位,所有 Token 都按该档价格计费",
"用量分段计价,每一段各自按对应档位价格计费(类似电费阶梯)": "用量分段计价,每一段各自按对应档位价格计费(类似电费阶梯)",
"Token 用量范围": "Token 用量范围",
"所有 Token": "所有 Token",
"前 {{count}} 个": "前 {{count}} 个",
"超过 {{count}} 个": "超过 {{count}} 个",
"第 {{n}} 档": "第 {{n}} 档",
"最高档": "最高档",
"此档上限(Token 数)": "此档上限(Token 数)",
"每百万 Token 价格": "每百万 Token 价格",
"进入此档额外收费": "进入此档额外收费",
"可选,用量达到此档时加收的固定费用": "可选,用量达到此档时加收的固定费用",
"添加更多档位": "添加更多档位",
"输入 Token 数": "输入 Token 数",
"输出 Token 数": "输出 Token 数",
"输入 Token 数量,查看按当前阶梯配置的预计费用。": "输入 Token 数量,查看按当前阶梯配置的预计费用。",
"开发者": "开发者",
"阶梯计费详情": "阶梯计费详情",
"预估环境": "预估环境",
"实际环境": "实际环境",
"预估额度": "预估额度",
"实际额度": "实际额度",
"跨阶梯": "跨阶梯",
"是": "是",
"否": "否",
"计费明细": "计费明细",
"阶梯序号": "阶梯序号",
"Token 类型": "Token 类型",
"阶梯内 Token 数": "阶梯内 Token 数",
"小计": "小计",
"输入": "输入",
"档位标签": "档位标签",
"用量范围": "用量范围",
"输入 Token": "输入 Token",
"输出 Token": "输出 Token",
"阶梯判断依据": "阶梯判断依据",
"根据哪个维度的 Token 数量决定落在哪一档": "根据哪个维度的 Token 数量决定落在哪一档",
"输入 Token 数 (p)": "输入 Token 数 (p)",
"输出 Token 数 (c)": "输出 Token 数 (c)",
"变量": "变量",
"函数": "函数",
"输入计费表达式...": "输入计费表达式...",
"表达式编辑": "表达式编辑",
"表达式错误": "表达式错误",
"命中档位": "命中档位",
"档": "档",
"输入 Token 数量,查看按当前配置的预计费用。": "输入 Token 数量,查看按当前配置的预计费用。",
"条件": "条件",
"添加条件": "添加条件",
"无条件(兜底档)": "无条件(兜底档)",
"兜底档": "兜底档",
"预设模板": "预设模板",
"每个档位可设置 0~2 个条件(对 p 和 c),最后一档为兜底档无需条件。": "每个档位可设置 0~2 个条件(对 p 和 c),最后一档为兜底档无需条件。",
"输出": "输出",
"阶梯配置摘要": "阶梯配置摘要",
"输入阶梯": "输入阶梯",
"输出阶梯": "输出阶梯",
"阶": "阶",
"规则版本": "规则版本",
"时间条件": "时间条件",
"小时": "小时",
"分钟": "分钟",
"星期": "星期",
"月份": "月份",
"日期": "日期",
"时区": "时区",
"跨夜范围": "跨夜范围",
"添加时间规则": "添加时间规则",
"起": "起",
"止": "止",
"值": "值",
"添加条件组": "添加条件组",
"添加条件": "添加条件",
"添加时间条件": "添加时间条件",
"同时满足": "同时满足",
"新年促销": "新年促销",
"第 {{n}} 组": "第 {{n}} 组",
"0=周日 1=周一 2=周二 3=周三 4=周四 5=周五 6=周六": "0=周日 1=周一 2=周二 3=周三 4=周四 5=周五 6=周六",
"1=一月 ... 12=十二月": "1=一月 ... 12=十二月"
}
}
@@ -17,7 +17,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React, { useMemo, useState } from 'react';
import React, { useCallback, useMemo, useState } from 'react';
import {
Banner,
Button,
@@ -49,6 +49,7 @@ import {
useModelPricingEditorState,
} from '../hooks/useModelPricingEditorState';
import { useIsMobile } from '../../../../hooks/common/useIsMobile';
import TieredPricingEditor from './TieredPricingEditor';
const { Text } = Typography;
const EMPTY_CANDIDATE_MODEL_NAMES = [];
@@ -123,6 +124,8 @@ export default function ModelPricingEditor({
handleOptionalFieldToggle,
handleNumericFieldChange,
handleBillingModeChange,
handleBillingExprChange,
handleRequestRuleExprChange,
handleSubmit,
addModel,
deleteModel,
@@ -135,6 +138,15 @@ export default function ModelPricingEditor({
filterMode,
});
const getExprModeLabel = useCallback((model) => {
if (model?.billingMode !== 'tiered_expr') {
return '';
}
return (model.billingExpr || '').includes('tier(')
? t('阶梯计费')
: t('表达式计费');
}, [t]);
const columns = useMemo(
() => [
{
@@ -175,10 +187,20 @@ export default function ModelPricingEditor({
dataIndex: 'billingMode',
key: 'billingMode',
render: (_, record) => (
<Tag color={record.billingMode === 'per-request' ? 'teal' : 'violet'}>
<Tag
color={
record.billingMode === 'per-request'
? 'teal'
: record.billingMode === 'tiered_expr'
? 'amber'
: 'violet'
}
>
{record.billingMode === 'per-request'
? t('按次计费')
: t('按量计费')}
: record.billingMode === 'tiered_expr'
? getExprModeLabel(record)
: t('按量计费')}
</Tag>
),
},
@@ -208,6 +230,7 @@ export default function ModelPricingEditor({
[
allowDeleteModel,
deleteModel,
getExprModeLabel,
selectedModelName,
selectedModelNames,
setSelectedModelName,
@@ -353,10 +376,20 @@ export default function ModelPricingEditor({
title={selectedModel ? selectedModel.name : t('模型计费编辑器')}
headerExtraContent={
selectedModel ? (
<Tag color='blue'>
<Tag
color={
selectedModel.billingMode === 'per-request'
? 'teal'
: selectedModel.billingMode === 'tiered_expr'
? 'amber'
: 'blue'
}
>
{selectedModel.billingMode === 'per-request'
? t('按次计费')
: t('按量计费')}
: selectedModel.billingMode === 'tiered_expr'
? getExprModeLabel(selectedModel)
: t('按量计费')}
</Tag>
) : null
}
@@ -381,10 +414,11 @@ export default function ModelPricingEditor({
>
<Radio value='per-token'>{t('按量计费')}</Radio>
<Radio value='per-request'>{t('按次计费')}</Radio>
<Radio value='tiered_expr'>{t('表达式/阶梯计费')}</Radio>
</RadioGroup>
<div className='mt-2 text-xs text-gray-500'>
{t(
'这个界面默认按价格填写,保存时会自动换算回后端需要的倍率 JSON。',
'普通按量/按次直接填价格就行;如果价格要跟请求参数或请求头联动,请切到表达式/阶梯计费。',
)}
</div>
</div>
@@ -415,6 +449,14 @@ export default function ModelPricingEditor({
onChange={(value) => handleNumericFieldChange('fixedPrice', value)}
extraText={t('适合 MJ / 任务类等按次收费模型。')}
/>
) : selectedModel.billingMode === 'tiered_expr' ? (
<TieredPricingEditor
model={selectedModel}
onExprChange={handleBillingExprChange}
requestRuleExpr={selectedModel.requestRuleExpr}
onRequestRuleExprChange={handleRequestRuleExprChange}
t={t}
/>
) : (
<>
<Card
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,443 @@
export const SOURCE_PARAM = 'param';
export const SOURCE_HEADER = 'header';
export const SOURCE_TIME = 'time';
export const MATCH_EQ = 'eq';
export const MATCH_CONTAINS = 'contains';
export const MATCH_GT = 'gt';
export const MATCH_GTE = 'gte';
export const MATCH_LT = 'lt';
export const MATCH_LTE = 'lte';
export const MATCH_EXISTS = 'exists';
export const MATCH_RANGE = 'range';
export const TIME_FUNCS = ['hour', 'minute', 'weekday', 'month', 'day'];
export const COMMON_TIMEZONES = [
{ value: 'Asia/Shanghai', label: 'CST (UTC+8 北京)' },
{ value: 'UTC', label: 'UTC' },
{ value: 'America/New_York', label: 'EST (UTC-5 纽约)' },
{ value: 'America/Los_Angeles', label: 'PST (UTC-8 洛杉矶)' },
{ value: 'America/Chicago', label: 'CST (UTC-6 芝加哥)' },
{ value: 'Europe/London', label: 'GMT (UTC+0 伦敦)' },
{ value: 'Europe/Berlin', label: 'CET (UTC+1 柏林)' },
{ value: 'Asia/Tokyo', label: 'JST (UTC+9 东京)' },
{ value: 'Asia/Singapore', label: 'SGT (UTC+8 新加坡)' },
{ value: 'Asia/Seoul', label: 'KST (UTC+9 首尔)' },
{ value: 'Australia/Sydney', label: 'AEST (UTC+10 悉尼)' },
];
export const NUMERIC_LITERAL_REGEX =
/^-?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?$/;
// ---------------------------------------------------------------------------
// Condition creators (no multiplier — multiplier lives on the group)
// ---------------------------------------------------------------------------
export function createEmptyCondition() {
return { source: SOURCE_PARAM, path: '', mode: MATCH_EQ, value: '' };
}
export function createEmptyTimeCondition() {
return {
source: SOURCE_TIME,
timeFunc: 'hour',
timezone: 'Asia/Shanghai',
mode: MATCH_GTE,
value: '',
rangeStart: '',
rangeEnd: '',
};
}
// ---------------------------------------------------------------------------
// Group creators
// ---------------------------------------------------------------------------
export function createEmptyRuleGroup() {
return { conditions: [createEmptyCondition()], multiplier: '' };
}
export function createEmptyTimeRuleGroup() {
return { conditions: [createEmptyTimeCondition()], multiplier: '' };
}
// Kept for backward compat with old preset format
export function createEmptyRequestRule() {
return { source: SOURCE_PARAM, path: '', mode: MATCH_EQ, value: '', multiplier: '' };
}
export function createEmptyTimeRule() {
return {
source: SOURCE_TIME, timeFunc: 'hour', timezone: 'Asia/Shanghai',
mode: MATCH_GTE, value: '', rangeStart: '', rangeEnd: '', multiplier: '',
};
}
// ---------------------------------------------------------------------------
// Match options
// ---------------------------------------------------------------------------
export function getRequestRuleMatchOptions(source, t) {
if (source === SOURCE_TIME) {
return [
{ value: MATCH_EQ, label: t('等于') },
{ value: MATCH_GTE, label: t('大于等于') },
{ value: MATCH_LT, label: t('小于') },
{ value: MATCH_RANGE, label: t('跨夜范围') },
];
}
const base = [
{ value: MATCH_EQ, label: t('等于') },
{ value: MATCH_CONTAINS, label: t('包含') },
{ value: MATCH_EXISTS, label: t('存在') },
];
if (source === SOURCE_HEADER) {
return base;
}
return [
...base,
{ value: MATCH_GT, label: t('大于') },
{ value: MATCH_GTE, label: t('大于等于') },
{ value: MATCH_LT, label: t('小于') },
{ value: MATCH_LTE, label: t('小于等于') },
];
}
// ---------------------------------------------------------------------------
// Normalize a single condition
// ---------------------------------------------------------------------------
export function normalizeCondition(cond) {
const source = cond?.source === SOURCE_TIME
? SOURCE_TIME
: cond?.source === SOURCE_HEADER
? SOURCE_HEADER
: SOURCE_PARAM;
if (source === SOURCE_TIME) {
const timeFunc = TIME_FUNCS.includes(cond?.timeFunc) ? cond.timeFunc : 'hour';
const options = getRequestRuleMatchOptions(SOURCE_TIME, (v) => v);
const mode = options.some((item) => item.value === cond?.mode) ? cond.mode : MATCH_GTE;
return {
source: SOURCE_TIME,
timeFunc,
timezone: cond?.timezone || 'Asia/Shanghai',
mode,
value: cond?.value == null ? '' : String(cond.value),
rangeStart: cond?.rangeStart == null ? '' : String(cond.rangeStart),
rangeEnd: cond?.rangeEnd == null ? '' : String(cond.rangeEnd),
};
}
const options = getRequestRuleMatchOptions(source, (v) => v);
const mode = options.some((item) => item.value === cond?.mode) ? cond.mode : MATCH_EQ;
return {
source,
path: cond?.path || '',
mode,
value: cond?.value == null ? '' : String(cond.value),
};
}
// Legacy compat wrapper
export function normalizeRequestRule(rule) {
const base = normalizeCondition(rule);
return { ...base, multiplier: rule?.multiplier == null ? '' : String(rule.multiplier) };
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
export function splitTopLevelMultiply(expr) {
const parts = [];
let start = 0;
let depth = 0;
for (let index = 0; index < expr.length; index += 1) {
const char = expr[index];
if (char === '(') depth += 1;
if (char === ')') depth -= 1;
if (depth === 0 && expr.slice(index, index + 3) === ' * ') {
parts.push(expr.slice(start, index).trim());
start = index + 3;
index += 2;
}
}
parts.push(expr.slice(start).trim());
return parts.filter(Boolean);
}
function splitTopLevelAnd(expr) {
const parts = [];
let start = 0;
let depth = 0;
for (let i = 0; i < expr.length; i += 1) {
const c = expr[i];
if (c === '(') depth += 1;
if (c === ')') depth -= 1;
if (depth === 0 && expr.slice(i, i + 4) === ' && ') {
parts.push(expr.slice(start, i).trim());
start = i + 4;
i += 3;
}
}
parts.push(expr.slice(start).trim());
return parts.filter(Boolean);
}
function parseExprLiteral(raw) {
const text = raw.trim();
if (text === 'true' || text === 'false') return text;
if (NUMERIC_LITERAL_REGEX.test(text)) return text;
try { return JSON.parse(text); } catch { return null; }
}
function buildExprLiteral(mode, value) {
const text = String(value || '').trim();
if (mode === MATCH_CONTAINS) return JSON.stringify(text);
if (text === 'true' || text === 'false') return text;
if (NUMERIC_LITERAL_REGEX.test(text)) return text;
return JSON.stringify(text);
}
// ---------------------------------------------------------------------------
// Build a single condition expression string (no ? mult : 1 wrapper)
// ---------------------------------------------------------------------------
function buildTimeConditionExpr(cond) {
const normalized = normalizeCondition(cond);
const { timeFunc, timezone, mode } = normalized;
const tz = JSON.stringify(timezone);
const fn = `${timeFunc}(${tz})`;
if (mode === MATCH_RANGE) {
const s = normalized.rangeStart.trim();
const e = normalized.rangeEnd.trim();
if (!NUMERIC_LITERAL_REGEX.test(s) || !NUMERIC_LITERAL_REGEX.test(e)) return '';
return `${fn} >= ${s} || ${fn} < ${e}`;
}
const v = normalized.value.trim();
if (!NUMERIC_LITERAL_REGEX.test(v)) return '';
const opMap = { [MATCH_EQ]: '==', [MATCH_GTE]: '>=', [MATCH_LT]: '<' };
return `${fn} ${opMap[mode] || '=='} ${v}`;
}
function buildRequestConditionExpr(cond) {
if (cond?.source === SOURCE_TIME) return buildTimeConditionExpr(cond);
const normalized = normalizeCondition(cond);
const path = normalized.path.trim();
if (!path) return '';
const sourceExpr = normalized.source === SOURCE_HEADER
? `header(${JSON.stringify(path)})`
: `param(${JSON.stringify(path)})`;
switch (normalized.mode) {
case MATCH_EXISTS:
return normalized.source === SOURCE_HEADER
? `${sourceExpr} != ""`
: `${sourceExpr} != nil`;
case MATCH_CONTAINS:
return normalized.source === SOURCE_HEADER
? `has(${sourceExpr}, ${buildExprLiteral(normalized.mode, normalized.value)})`
: `${sourceExpr} != nil && has(${sourceExpr}, ${buildExprLiteral(normalized.mode, normalized.value)})`;
case MATCH_GT: case MATCH_GTE: case MATCH_LT: case MATCH_LTE: {
const opMap = { [MATCH_GT]: '>', [MATCH_GTE]: '>=', [MATCH_LT]: '<', [MATCH_LTE]: '<=' };
if (!NUMERIC_LITERAL_REGEX.test(String(normalized.value).trim())) return '';
return `${sourceExpr} != nil && ${sourceExpr} ${opMap[normalized.mode]} ${String(normalized.value).trim()}`;
}
case MATCH_EQ:
default:
return `${sourceExpr} == ${buildExprLiteral(normalized.mode, normalized.value)}`;
}
}
// ---------------------------------------------------------------------------
// Build a group factor: (cond1 && cond2 ? mult : 1)
// ---------------------------------------------------------------------------
function buildRuleGroupFactor(group) {
const multiplier = (group.multiplier || '').trim();
if (!NUMERIC_LITERAL_REGEX.test(multiplier)) return '';
const condExprs = (group.conditions || [])
.map(buildRequestConditionExpr)
.filter(Boolean);
if (condExprs.length === 0) return '';
const combined = condExprs.length === 1
? condExprs[0]
: condExprs.map((e) => (e.includes(' || ') ? `(${e})` : e)).join(' && ');
return `(${combined} ? ${multiplier} : 1)`;
}
export function buildRequestRuleExpr(groups) {
return (groups || []).map(buildRuleGroupFactor).filter(Boolean).join(' * ');
}
// ---------------------------------------------------------------------------
// Parse a single condition from an expression fragment
// ---------------------------------------------------------------------------
function tryParseTimeCondition(expr) {
// Range: hour("tz") >= s || hour("tz") < e
let m = expr.match(
/^(hour|minute|weekday|month|day)\("([^"]+)"\) >= ([\d.eE+-]+) \|\| \1\("\2"\) < ([\d.eE+-]+)$/,
);
if (m) {
return {
source: SOURCE_TIME, timeFunc: m[1], timezone: m[2],
mode: MATCH_RANGE, value: '', rangeStart: m[3], rangeEnd: m[4],
};
}
// Wrapped range: (hour("tz") >= s || hour("tz") < e)
m = expr.match(
/^\((hour|minute|weekday|month|day)\("([^"]+)"\) >= ([\d.eE+-]+) \|\| \1\("\2"\) < ([\d.eE+-]+)\)$/,
);
if (m) {
return {
source: SOURCE_TIME, timeFunc: m[1], timezone: m[2],
mode: MATCH_RANGE, value: '', rangeStart: m[3], rangeEnd: m[4],
};
}
// Simple: hour("tz") op value
m = expr.match(
/^(hour|minute|weekday|month|day)\("([^"]+)"\) (==|>=|<) ([\d.eE+-]+)$/,
);
if (m) {
const opMap = { '==': MATCH_EQ, '>=': MATCH_GTE, '<': MATCH_LT };
return {
source: SOURCE_TIME, timeFunc: m[1], timezone: m[2],
mode: opMap[m[3]] || MATCH_EQ, value: m[4], rangeStart: '', rangeEnd: '',
};
}
return null;
}
function tryParseRequestCondition(expr) {
const tc = tryParseTimeCondition(expr);
if (tc) return tc;
let m = expr.match(/^header\("([^"]+)"\) != ""$/);
if (m) return { source: SOURCE_HEADER, path: m[1], mode: MATCH_EXISTS, value: '' };
m = expr.match(/^param\("([^"]+)"\) != nil$/);
if (m) return { source: SOURCE_PARAM, path: m[1], mode: MATCH_EXISTS, value: '' };
m = expr.match(/^has\(header\("([^"]+)"\), ((?:"(?:[^"\\]|\\.)*"))\)$/);
if (m) return { source: SOURCE_HEADER, path: m[1], mode: MATCH_CONTAINS, value: JSON.parse(m[2]) };
m = expr.match(/^param\("([^"]+)"\) != nil && has\(param\("([^"]+)"\), ((?:"(?:[^"\\]|\\.)*"))\)$/);
if (m && m[1] === m[2]) return { source: SOURCE_PARAM, path: m[1], mode: MATCH_CONTAINS, value: JSON.parse(m[3]) };
m = expr.match(/^param\("([^"]+)"\) != nil && param\("([^"]+)"\) (>|>=|<|<=) ([\d.eE+-]+)$/);
if (m && m[1] === m[2]) {
const opMap = { '>': MATCH_GT, '>=': MATCH_GTE, '<': MATCH_LT, '<=': MATCH_LTE };
return { source: SOURCE_PARAM, path: m[1], mode: opMap[m[3]], value: m[4] };
}
m = expr.match(/^(param|header)\("([^"]+)"\) == (.+)$/);
if (m) {
const parsedValue = parseExprLiteral(m[3]);
if (parsedValue === null) return null;
return { source: m[1], path: m[2], mode: MATCH_EQ, value: String(parsedValue) };
}
return null;
}
// ---------------------------------------------------------------------------
// Parse a group factor: (cond1 && cond2 ? mult : 1)
// ---------------------------------------------------------------------------
function tryParseRuleGroupFactor(part) {
// Must be wrapped in ( ... ? mult : 1)
const m = part.match(/^\((.+) \? ([\d.eE+-]+) : 1\)$/s);
if (!m) return null;
const conditionStr = m[1];
const multiplier = m[2];
const andParts = splitTopLevelAnd(conditionStr);
const conditions = [];
for (const ap of andParts) {
const cond = tryParseRequestCondition(ap.trim());
if (!cond) return null;
conditions.push(normalizeCondition(cond));
}
if (conditions.length === 0) return null;
return { conditions, multiplier };
}
export function tryParseRequestRuleExpr(expr) {
const trimmed = (expr || '').trim();
if (!trimmed) return [];
const parts = splitTopLevelMultiply(trimmed);
const groups = [];
for (const part of parts) {
const group = tryParseRuleGroupFactor(part);
if (!group) return null;
groups.push(group);
}
return groups;
}
// ---------------------------------------------------------------------------
// Combine / split billing expr and request rules
// ---------------------------------------------------------------------------
function hasFullOuterParens(expr) {
if (!expr.startsWith('(') || !expr.endsWith(')')) return false;
let depth = 0;
for (let i = 0; i < expr.length; i += 1) {
if (expr[i] === '(') depth += 1;
if (expr[i] === ')') depth -= 1;
if (depth === 0 && i < expr.length - 1) return false;
}
return depth === 0;
}
export function unwrapOuterParens(expr) {
let current = (expr || '').trim();
while (hasFullOuterParens(current)) {
current = current.slice(1, -1).trim();
}
return current;
}
export function combineBillingExpr(baseExpr, requestRuleExpr) {
const base = (baseExpr || '').trim();
const rules = (requestRuleExpr || '').trim();
if (!base) return '';
if (!rules) return base;
return `(${base}) * ${rules}`;
}
export function splitBillingExprAndRequestRules(expr) {
const trimmed = (expr || '').trim();
if (!trimmed) return { billingExpr: '', requestRuleExpr: '' };
const parts = splitTopLevelMultiply(trimmed);
if (parts.length <= 1) return { billingExpr: trimmed, requestRuleExpr: '' };
const ruleParts = [];
const baseParts = [];
parts.forEach((part) => {
if (tryParseRequestRuleExpr(part) !== null && tryParseRequestRuleExpr(part).length > 0) {
ruleParts.push(part);
} else {
baseParts.push(part);
}
});
if (ruleParts.length === 0 || baseParts.length !== 1) {
return { billingExpr: trimmed, requestRuleExpr: '' };
}
return {
billingExpr: unwrapOuterParens(baseParts[0]),
requestRuleExpr: ruleParts.join(' * '),
};
}
@@ -1,5 +1,9 @@
import { useEffect, useMemo, useState } from 'react';
import { API, showError, showSuccess } from '../../../../helpers';
import {
combineBillingExpr,
splitBillingExprAndRequestRules,
} from '../components/requestRuleExpr';
export const PAGE_SIZE = 10;
export const PRICE_SUFFIX = '$/1M tokens';
@@ -18,6 +22,8 @@ const EMPTY_MODEL = {
imagePrice: '',
audioInputPrice: '',
audioOutputPrice: '',
billingExpr: '',
requestRuleExpr: '',
rawRatios: {
modelRatio: '',
completionRatio: '',
@@ -98,6 +104,22 @@ const normalizeCompletionRatioMeta = (rawMeta) => {
};
const buildModelState = (name, sourceMaps) => {
const billingMode = sourceMaps.ModelBillingMode?.[name];
if (billingMode === 'tiered_expr') {
const fullBillingExpr = sourceMaps.ModelBillingExpr?.[name] || '';
const { billingExpr, requestRuleExpr } =
splitBillingExprAndRequestRules(fullBillingExpr);
return {
...EMPTY_MODEL,
name,
billingMode: 'tiered_expr',
billingExpr,
requestRuleExpr,
rawRatios: { ...EMPTY_MODEL.rawRatios },
hasConflict: false,
};
}
const modelRatio = toNumericString(sourceMaps.ModelRatio[name]);
const completionRatio = toNumericString(sourceMaps.CompletionRatio[name]);
const completionRatioMeta = normalizeCompletionRatioMeta(
@@ -159,6 +181,7 @@ const buildModelState = (name, sourceMaps) => {
toNumberOrNull(audioInputPrice) !== null && hasValue(audioCompletionRatio)
? formatNumber(Number(audioInputPrice) * Number(audioCompletionRatio))
: '',
requestRuleExpr: '',
rawRatios: {
modelRatio,
completionRatio,
@@ -183,12 +206,16 @@ const buildModelState = (name, sourceMaps) => {
};
export const isBasePricingUnset = (model) =>
model.billingMode !== 'tiered_expr' &&
!hasValue(model.fixedPrice) && !hasValue(model.inputPrice);
export const getModelWarnings = (model, t) => {
if (!model) {
return [];
}
if (model.billingMode === 'tiered_expr') {
return [];
}
const warnings = [];
const hasDerivedPricing = [
model.inputPrice,
@@ -244,8 +271,22 @@ export const getModelWarnings = (model, t) => {
};
export const buildSummaryText = (model, t) => {
const requestRuleSuffix =
model.billingMode === 'tiered_expr' && model.requestRuleExpr
? `${t('请求规则')}`
: '';
if (model.billingMode === 'tiered_expr') {
const expr = model.billingExpr;
if (!expr) return `${t('表达式计费')}${requestRuleSuffix}`;
const tierCount = (expr.match(/tier\(/g) || []).length;
if (tierCount === 0) {
return `${t('表达式计费')}${requestRuleSuffix}`;
}
return `${t('阶梯计费')} (${tierCount} ${t('档')})${requestRuleSuffix}`;
}
if (model.billingMode === 'per-request' && hasValue(model.fixedPrice)) {
return `${t('按次')} $${model.fixedPrice} / ${t('次')}`;
return `${t('按次')} $${model.fixedPrice} / ${t('次')}${requestRuleSuffix}`;
}
if (hasValue(model.inputPrice)) {
@@ -259,10 +300,10 @@ export const buildSummaryText = (model, t) => {
].filter(hasValue).length;
const extraLabel =
extraCount > 0 ? `${t('额外价格项')} ${extraCount}` : '';
return `${t('输入')} $${model.inputPrice}${extraLabel}`;
return `${t('输入')} $${model.inputPrice}${extraLabel}${requestRuleSuffix}`;
}
return t('未设置价格');
return `${t('未设置价格')}${requestRuleSuffix}`;
};
export const buildOptionalFieldToggles = (model) => ({
@@ -395,20 +436,53 @@ const serializeModel = (model, t) => {
export const buildPreviewRows = (model, t) => {
if (!model) return [];
const finalBillingExpr = combineBillingExpr(
model.billingExpr,
model.requestRuleExpr,
);
if (model.billingMode === 'tiered_expr') {
const rows = [
{
key: 'BillingMode',
label: 'ModelBillingMode',
value: 'tiered_expr',
},
];
if (finalBillingExpr) {
const tierCount = (model.billingExpr.match(/tier\(/g) || []).length;
rows.push({
key: 'BillingExpr',
label: 'ModelBillingExpr',
value:
tierCount > 0
? `${tierCount} ${t('档')}${
finalBillingExpr.length > 60
? finalBillingExpr.slice(0, 60) + '...'
: finalBillingExpr
}`
: finalBillingExpr.length > 60
? finalBillingExpr.slice(0, 60) + '...'
: finalBillingExpr,
});
}
return rows;
}
if (model.billingMode === 'per-request') {
return [
const rows = [
{
key: 'ModelPrice',
label: 'ModelPrice',
value: hasValue(model.fixedPrice) ? model.fixedPrice : t('空'),
},
];
return rows;
}
const inputPrice = toNumberOrNull(model.inputPrice);
if (inputPrice === null) {
return [
const rows = [
{
key: 'ModelRatio',
label: 'ModelRatio',
@@ -459,6 +533,7 @@ export const buildPreviewRows = (model, t) => {
: t('空'),
},
];
return rows;
}
const completionPrice = toNumberOrNull(model.completionPrice);
@@ -468,7 +543,7 @@ export const buildPreviewRows = (model, t) => {
const audioInputPrice = toNumberOrNull(model.audioInputPrice);
const audioOutputPrice = toNumberOrNull(model.audioOutputPrice);
return [
const rows = [
{
key: 'ModelRatio',
label: 'ModelRatio',
@@ -522,6 +597,7 @@ export const buildPreviewRows = (model, t) => {
: t('空'),
},
];
return rows;
};
export function useModelPricingEditorState({
@@ -552,6 +628,8 @@ export function useModelPricingEditorState({
ImageRatio: parseOptionJSON(options.ImageRatio),
AudioRatio: parseOptionJSON(options.AudioRatio),
AudioCompletionRatio: parseOptionJSON(options.AudioCompletionRatio),
ModelBillingMode: parseOptionJSON(options.ModelBillingMode),
ModelBillingExpr: parseOptionJSON(options.ModelBillingExpr),
};
const names = new Set([
@@ -565,6 +643,8 @@ export function useModelPricingEditorState({
...Object.keys(sourceMaps.ImageRatio),
...Object.keys(sourceMaps.AudioRatio),
...Object.keys(sourceMaps.AudioCompletionRatio),
...Object.keys(sourceMaps.ModelBillingMode),
...Object.keys(sourceMaps.ModelBillingExpr),
]);
const nextModels = Array.from(names)
@@ -775,10 +855,29 @@ export function useModelPricingEditorState({
};
const handleBillingModeChange = (value) => {
if (!selectedModel) return;
upsertModel(selectedModel.name, (model) => {
const next = { ...model, billingMode: value };
if (value === 'tiered_expr' && !model.billingExpr) {
next.billingExpr = 'tier("default", p * 0 + c * 0)';
}
return next;
});
};
const handleBillingExprChange = (newExpr) => {
if (!selectedModel) return;
upsertModel(selectedModel.name, (model) => ({
...model,
billingMode: value,
billingExpr: newExpr,
}));
};
const handleRequestRuleExprChange = (newExpr) => {
if (!selectedModel) return;
upsertModel(selectedModel.name, (model) => ({
...model,
requestRuleExpr: newExpr,
}));
};
@@ -854,6 +953,8 @@ export function useModelPricingEditorState({
imagePrice: selectedModel.imagePrice,
audioInputPrice: selectedModel.audioInputPrice,
audioOutputPrice: selectedModel.audioOutputPrice,
billingExpr: selectedModel.billingExpr || '',
requestRuleExpr: selectedModel.requestRuleExpr || '',
};
if (
@@ -915,7 +1016,26 @@ export function useModelPricingEditorState({
AudioCompletionRatio: {},
};
const tieredOutput = {
ModelBillingMode: {},
ModelBillingExpr: {},
};
for (const model of models) {
if (model.billingMode === 'tiered_expr') {
tieredOutput.ModelBillingMode[model.name] = 'tiered_expr';
const finalBillingExpr = combineBillingExpr(
model.billingExpr,
model.requestRuleExpr,
);
if (finalBillingExpr) {
tieredOutput.ModelBillingExpr[model.name] = finalBillingExpr;
}
}
if (model.billingMode === 'tiered_expr') {
continue;
}
const serialized = serializeModel(model, t);
Object.entries(serialized).forEach(([key, value]) => {
if (value !== null) {
@@ -924,12 +1044,20 @@ export function useModelPricingEditorState({
});
}
const requestQueue = Object.entries(output).map(([key, value]) =>
API.put('/api/option/', {
key,
value: JSON.stringify(value, null, 2),
}),
);
const requestQueue = [
...Object.entries(output).map(([key, value]) =>
API.put('/api/option/', {
key,
value: JSON.stringify(value, null, 2),
}),
),
...Object.entries(tieredOutput).map(([key, value]) =>
API.put('/api/option/', {
key,
value: JSON.stringify(value, null, 2),
}),
),
];
const results = await Promise.all(requestQueue);
for (const res of results) {
@@ -970,6 +1098,8 @@ export function useModelPricingEditorState({
handleOptionalFieldToggle,
handleNumericFieldChange,
handleBillingModeChange,
handleBillingExprChange,
handleRequestRuleExprChange,
handleSubmit,
addModel,
deleteModel,