Merge branch 'origin/main' into nightly

Resolve 4 conflicts:
- relay/compatible_handler.go: accept main's refactor (postConsumeQuota -> service.PostTextConsumeQuota)
- service/quota.go: accept main's PostClaudeConsumeQuota deletion, keep nightly's tiered billing in PostWssConsumeQuota and PostAudioConsumeQuota
- web/src/i18n/locales/{en,zh-CN}.json: merge both sets of translation keys

Post-merge integration:
- Add tiered billing (TryTieredSettle, InjectTieredBillingInfo) to PostTextConsumeQuota
- Update tool pricing calls to use nightly's generic GetToolPriceForModel/GetToolPrice API
This commit is contained in:
CaIon
2026-04-02 00:39:13 +08:00
122 changed files with 8307 additions and 3923 deletions
+1 -1
View File
@@ -70,7 +70,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 {
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
} else {
postConsumeQuota(c, info, usage.(*dto.Usage))
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
}
return nil
+11 -6
View File
@@ -171,12 +171,17 @@ type AliImageRequest struct {
}
type AliImageParameters struct {
Size string `json:"size,omitempty"`
N int `json:"n,omitempty"`
Steps string `json:"steps,omitempty"`
Scale string `json:"scale,omitempty"`
Watermark *bool `json:"watermark,omitempty"`
PromptExtend *bool `json:"prompt_extend,omitempty"`
Size string `json:"size,omitempty"`
N int `json:"n,omitempty"`
Steps string `json:"steps,omitempty"`
Scale string `json:"scale,omitempty"`
Watermark *bool `json:"watermark,omitempty"`
PromptExtend *bool `json:"prompt_extend,omitempty"`
ThinkingMode *bool `json:"thinking_mode,omitempty"`
EnableSequential *bool `json:"enable_sequential,omitempty"`
BboxList any `json:"bbox_list,omitempty"`
ColorPalette any `json:"color_palette,omitempty"`
Seed *int `json:"seed,omitempty"`
}
func (p *AliImageParameters) PromptExtendValue() bool {
+1 -2
View File
@@ -54,7 +54,6 @@ func oaiImage2AliImageRequest(info *relaycommon.RelayInfo, request dto.ImageRequ
}
}
// 检查n参数
if imageRequest.Parameters.N != 0 {
info.PriceData.AddOtherRatio("n", float64(imageRequest.Parameters.N))
}
@@ -181,6 +180,7 @@ func oaiFormEdit2AliImageEdit(c *gin.Context, info *relaycommon.RelayInfo, reque
},
}
imageRequest.Parameters = AliImageParameters{
N: int(lo.FromPtrOr(request.N, uint(1))),
Watermark: request.Watermark,
}
return &imageRequest, nil
@@ -328,7 +328,6 @@ func aliImageHandler(a *Adaptor, c *gin.Context, resp *http.Response, info *rela
}
imageResponses := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat)
// 可能生成多张图片,修正计费数量n
if aliResponse.Usage.ImageCount != 0 {
info.PriceData.AddOtherRatio("n", float64(aliResponse.Usage.ImageCount))
} else if len(imageResponses.Data) != 0 {
+2 -1
View File
@@ -40,7 +40,8 @@ func oaiFormEdit2WanxImageEdit(c *gin.Context, info *relaycommon.RelayInfo, requ
}
func isOldWanModel(modelName string) bool {
return strings.Contains(modelName, "wan") && !strings.Contains(modelName, "wan2.6")
return strings.Contains(modelName, "wan") &&
!lo.SomeBy([]string{"wan2.6", "wan2.7"}, func(v string) bool { return strings.Contains(modelName, v) })
}
func isWanModel(modelName string) bool {
+6 -7
View File
@@ -116,12 +116,12 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAI
func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
usage := &dto.Usage{}
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
var baiduResponse BaiduChatStreamResponse
err := common.Unmarshal([]byte(data), &baiduResponse)
if err != nil {
if err := common.Unmarshal([]byte(data), &baiduResponse); err != nil {
common.SysLog("error unmarshalling stream response: " + err.Error())
return true
sr.Error(err)
return
}
if baiduResponse.Usage.TotalTokens != 0 {
usage.TotalTokens = baiduResponse.Usage.TotalTokens
@@ -129,11 +129,10 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
}
response := streamResponseBaidu2OpenAI(&baiduResponse)
err = helper.ObjectData(c, response)
if err != nil {
if err := helper.ObjectData(c, response); err != nil {
common.SysLog("error sending stream response: " + err.Error())
sr.Error(err)
}
return true
})
service.CloseResponseBodyGracefully(resp)
return nil, usage
+26 -4
View File
@@ -5,6 +5,7 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
@@ -41,11 +42,32 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
baseURL := fmt.Sprintf("%s/v1/messages", info.ChannelBaseUrl)
if info.IsClaudeBetaQuery {
baseURL = baseURL + "?beta=true"
requestURL := fmt.Sprintf("%s/v1/messages", info.ChannelBaseUrl)
if !shouldAppendClaudeBetaQuery(info) {
return requestURL, nil
}
return baseURL, nil
parsedURL, err := url.Parse(requestURL)
if err != nil {
return "", err
}
query := parsedURL.Query()
query.Set("beta", "true")
parsedURL.RawQuery = query.Encode()
return parsedURL.String(), nil
}
func shouldAppendClaudeBetaQuery(info *relaycommon.RelayInfo) bool {
if info == nil {
return false
}
if info.IsClaudeBetaQuery {
return true
}
if info.ChannelOtherSettings.ClaudeBetaQuery {
return true
}
return false
}
func CommonClaudeHeadersOperation(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) {
+40 -5
View File
@@ -555,6 +555,35 @@ type ClaudeResponseInfo struct {
Done bool
}
func cacheCreationTokensForOpenAIUsage(usage *dto.Usage) int {
if usage == nil {
return 0
}
splitCacheCreationTokens := usage.ClaudeCacheCreation5mTokens + usage.ClaudeCacheCreation1hTokens
if splitCacheCreationTokens == 0 {
return usage.PromptTokensDetails.CachedCreationTokens
}
if usage.PromptTokensDetails.CachedCreationTokens > splitCacheCreationTokens {
return usage.PromptTokensDetails.CachedCreationTokens
}
return splitCacheCreationTokens
}
func buildOpenAIStyleUsageFromClaudeUsage(usage *dto.Usage) dto.Usage {
if usage == nil {
return dto.Usage{}
}
clone := *usage
cacheCreationTokens := cacheCreationTokensForOpenAIUsage(usage)
totalInputTokens := usage.PromptTokens + usage.PromptTokensDetails.CachedTokens + cacheCreationTokens
clone.PromptTokens = totalInputTokens
clone.InputTokens = totalInputTokens
clone.TotalTokens = totalInputTokens + usage.CompletionTokens
clone.UsageSemantic = "openai"
clone.UsageSource = "anthropic"
return clone
}
func buildMessageDeltaPatchUsage(claudeResponse *dto.ClaudeResponse, claudeInfo *ClaudeResponseInfo) *dto.ClaudeUsage {
usage := &dto.ClaudeUsage{}
if claudeResponse != nil && claudeResponse.Usage != nil {
@@ -643,6 +672,7 @@ func FormatClaudeResponseInfo(claudeResponse *dto.ClaudeResponse, oaiResponse *d
// message_start, 获取usage
if claudeResponse.Message != nil && claudeResponse.Message.Usage != nil {
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
claudeInfo.Usage.UsageSemantic = "anthropic"
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
claudeInfo.Usage.ClaudeCacheCreation5mTokens = claudeResponse.Message.Usage.GetCacheCreation5mTokens()
@@ -661,6 +691,7 @@ func FormatClaudeResponseInfo(claudeResponse *dto.ClaudeResponse, oaiResponse *d
} else if claudeResponse.Type == "message_delta" {
// 最终的usage获取
if claudeResponse.Usage != nil {
claudeInfo.Usage.UsageSemantic = "anthropic"
if claudeResponse.Usage.InputTokens > 0 {
// 不叠加,只取最新的
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
@@ -754,12 +785,16 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
}
claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
}
if claudeInfo.Usage != nil {
claudeInfo.Usage.UsageSemantic = "anthropic"
}
if info.RelayFormat == types.RelayFormatClaude {
//
} else if info.RelayFormat == types.RelayFormatOpenAI {
if info.ShouldIncludeUsage {
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
openAIUsage := buildOpenAIStyleUsageFromClaudeUsage(claudeInfo.Usage)
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, openAIUsage)
err := helper.ObjectData(c, response)
if err != nil {
common.SysLog("send final response failed: " + err.Error())
@@ -778,12 +813,11 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
Usage: &dto.Usage{},
}
var err *types.NewAPIError
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
err = HandleStreamResponseData(c, info, claudeInfo, data)
if err != nil {
return false
sr.Stop(err)
}
return true
})
if err != nil {
return nil, err
@@ -810,6 +844,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
claudeInfo.Usage.UsageSemantic = "anthropic"
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
claudeInfo.Usage.ClaudeCacheCreation5mTokens = claudeResponse.Usage.GetCacheCreation5mTokens()
@@ -819,7 +854,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
switch info.RelayFormat {
case types.RelayFormatOpenAI:
openaiResponse := ResponseClaude2OpenAI(&claudeResponse)
openaiResponse.Usage = *claudeInfo.Usage
openaiResponse.Usage = buildOpenAIStyleUsageFromClaudeUsage(claudeInfo.Usage)
responseData, err = json.Marshal(openaiResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody)
+82
View File
@@ -173,3 +173,85 @@ func TestFormatClaudeResponseInfo_ContentBlockDelta(t *testing.T) {
t.Errorf("ResponseText = %q, want %q", claudeInfo.ResponseText.String(), "hello")
}
}
func TestBuildOpenAIStyleUsageFromClaudeUsage(t *testing.T) {
usage := &dto.Usage{
PromptTokens: 100,
CompletionTokens: 20,
PromptTokensDetails: dto.InputTokenDetails{
CachedTokens: 30,
CachedCreationTokens: 50,
},
ClaudeCacheCreation5mTokens: 10,
ClaudeCacheCreation1hTokens: 20,
UsageSemantic: "anthropic",
}
openAIUsage := buildOpenAIStyleUsageFromClaudeUsage(usage)
if openAIUsage.PromptTokens != 180 {
t.Fatalf("PromptTokens = %d, want 180", openAIUsage.PromptTokens)
}
if openAIUsage.InputTokens != 180 {
t.Fatalf("InputTokens = %d, want 180", openAIUsage.InputTokens)
}
if openAIUsage.TotalTokens != 200 {
t.Fatalf("TotalTokens = %d, want 200", openAIUsage.TotalTokens)
}
if openAIUsage.UsageSemantic != "openai" {
t.Fatalf("UsageSemantic = %s, want openai", openAIUsage.UsageSemantic)
}
if openAIUsage.UsageSource != "anthropic" {
t.Fatalf("UsageSource = %s, want anthropic", openAIUsage.UsageSource)
}
}
func TestBuildOpenAIStyleUsageFromClaudeUsagePreservesCacheCreationRemainder(t *testing.T) {
tests := []struct {
name string
cachedCreationTokens int
cacheCreationTokens5m int
cacheCreationTokens1h int
expectedTotalInputToken int
}{
{
name: "prefers aggregate when it includes remainder",
cachedCreationTokens: 50,
cacheCreationTokens5m: 10,
cacheCreationTokens1h: 20,
expectedTotalInputToken: 180,
},
{
name: "falls back to split tokens when aggregate missing",
cachedCreationTokens: 0,
cacheCreationTokens5m: 10,
cacheCreationTokens1h: 20,
expectedTotalInputToken: 160,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
usage := &dto.Usage{
PromptTokens: 100,
CompletionTokens: 20,
PromptTokensDetails: dto.InputTokenDetails{
CachedTokens: 30,
CachedCreationTokens: tt.cachedCreationTokens,
},
ClaudeCacheCreation5mTokens: tt.cacheCreationTokens5m,
ClaudeCacheCreation1hTokens: tt.cacheCreationTokens1h,
UsageSemantic: "anthropic",
}
openAIUsage := buildOpenAIStyleUsageFromClaudeUsage(usage)
if openAIUsage.PromptTokens != tt.expectedTotalInputToken {
t.Fatalf("PromptTokens = %d, want %d", openAIUsage.PromptTokens, tt.expectedTotalInputToken)
}
if openAIUsage.InputTokens != tt.expectedTotalInputToken {
t.Fatalf("InputTokens = %d, want %d", openAIUsage.InputTokens, tt.expectedTotalInputToken)
}
})
}
}
+16 -17
View File
@@ -223,33 +223,32 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
usage := &dto.Usage{}
var nodeToken int
helper.SetEventStreamHeaders(c)
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
var difyResponse DifyChunkChatCompletionResponse
err := json.Unmarshal([]byte(data), &difyResponse)
if err != nil {
if err := json.Unmarshal([]byte(data), &difyResponse); err != nil {
common.SysLog("error unmarshalling stream response: " + err.Error())
return true
sr.Error(err)
return
}
var openaiResponse dto.ChatCompletionsStreamResponse
if difyResponse.Event == "message_end" {
usage = &difyResponse.MetaData.Usage
return false
sr.Done()
return
} else if difyResponse.Event == "error" {
return false
} else {
openaiResponse = *streamResponseDify2OpenAI(difyResponse)
if len(openaiResponse.Choices) != 0 {
responseText += openaiResponse.Choices[0].Delta.GetContentString()
if openaiResponse.Choices[0].Delta.ReasoningContent != nil {
nodeToken += 1
}
sr.Stop(fmt.Errorf("dify error event"))
return
}
openaiResponse := *streamResponseDify2OpenAI(difyResponse)
if len(openaiResponse.Choices) != 0 {
responseText += openaiResponse.Choices[0].Delta.GetContentString()
if openaiResponse.Choices[0].Delta.ReasoningContent != nil {
nodeToken += 1
}
}
err = helper.ObjectData(c, openaiResponse)
if err != nil {
if err := helper.ObjectData(c, openaiResponse); err != nil {
common.SysLog(err.Error())
sr.Error(err)
}
return true
})
helper.Done(c)
if usage.TotalTokens == 0 {
+7 -6
View File
@@ -1305,12 +1305,11 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
var imageCount int
responseText := strings.Builder{}
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
var geminiResponse dto.GeminiChatResponse
err := common.UnmarshalJsonStr(data, &geminiResponse)
if err != nil {
logger.LogError(c, "error unmarshalling stream response: "+err.Error())
return false
if err := common.UnmarshalJsonStr(data, &geminiResponse); err != nil {
sr.Stop(fmt.Errorf("unmarshal: %w", err))
return
}
if len(geminiResponse.Candidates) == 0 && geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
@@ -1335,7 +1334,9 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
*usage = mappedUsage
}
return callback(data, &geminiResponse)
if !callback(data, &geminiResponse) {
sr.Stop(fmt.Errorf("gemini callback stopped"))
}
})
if imageCount != 0 {
+7 -7
View File
@@ -35,21 +35,21 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
c.Writer.WriteHeader(resp.StatusCode)
if info.IsStream {
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
if service.SundaySearch(data, "usage") {
var simpleResponse dto.SimpleResponse
err := common.Unmarshal([]byte(data), &simpleResponse)
if err != nil {
if err := common.Unmarshal([]byte(data), &simpleResponse); err != nil {
logger.LogError(c, err.Error())
}
if simpleResponse.Usage.TotalTokens != 0 {
sr.Error(err)
} else if simpleResponse.Usage.TotalTokens != 0 {
usage.PromptTokens = simpleResponse.Usage.InputTokens
usage.CompletionTokens = simpleResponse.OutputTokens
usage.TotalTokens = simpleResponse.TotalTokens
}
}
_ = helper.StringData(c, data)
return true
if err := helper.StringData(c, data); err != nil {
sr.Error(err)
}
})
} else {
common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
+27 -16
View File
@@ -296,15 +296,17 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
return true
}
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
if streamErr != nil {
return false
sr.Stop(streamErr)
return
}
var streamResp dto.ResponsesStreamResponse
if err := common.UnmarshalJsonStr(data, &streamResp); err != nil {
logger.LogError(c, "failed to unmarshal responses stream event: "+err.Error())
return true
sr.Error(err)
return
}
switch streamResp.Type {
@@ -320,14 +322,16 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
//case "response.reasoning_text.delta":
//if !sendReasoningDelta(streamResp.Delta) {
// return false
// sr.Stop(streamErr)
// return
//}
//case "response.reasoning_text.done":
case "response.reasoning_summary_text.delta":
if !sendReasoningSummaryDelta(streamResp.Delta) {
return false
sr.Stop(streamErr)
return
}
case "response.reasoning_summary_text.done":
@@ -349,12 +353,14 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
// delta := stringDeltaFromPrefix(prev, next)
// reasoningSummaryTextByKey[key] = next
// if !sendReasoningSummaryDelta(delta) {
// return false
// sr.Stop(streamErr)
// return
// }
case "response.output_text.delta":
if !sendStartIfNeeded() {
return false
sr.Stop(streamErr)
return
}
if streamResp.Delta != "" {
@@ -376,7 +382,8 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
},
}
if !sendChatChunk(chunk) {
return false
sr.Stop(streamErr)
return
}
}
@@ -414,7 +421,8 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
}
if !sendToolCallDelta(callID, name, argsDelta) {
return false
sr.Stop(streamErr)
return
}
case "response.function_call_arguments.delta":
@@ -428,7 +436,8 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
}
toolCallArgsByID[callID] += streamResp.Delta
if !sendToolCallDelta(callID, "", streamResp.Delta) {
return false
sr.Stop(streamErr)
return
}
case "response.function_call_arguments.done":
@@ -467,7 +476,8 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
}
if !sendStartIfNeeded() {
return false
sr.Stop(streamErr)
return
}
if !sentStop {
if info.RelayFormat == types.RelayFormatClaude && info.ClaudeConvertInfo != nil {
@@ -479,7 +489,8 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
}
stop := helper.GenerateStopResponse(responseId, createAt, model, finishReason)
if !sendChatChunk(stop) {
return false
sr.Stop(streamErr)
return
}
sentStop = true
}
@@ -488,16 +499,16 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
if streamResp.Response != nil {
if oaiErr := streamResp.Response.GetOpenAIError(); oaiErr != nil && oaiErr.Type != "" {
streamErr = types.WithOpenAIError(*oaiErr, http.StatusInternalServerError)
return false
sr.Stop(streamErr)
return
}
}
streamErr = types.NewOpenAIError(fmt.Errorf("responses stream error: %s", streamResp.Type), types.ErrorCodeBadResponse, http.StatusInternalServerError)
return false
sr.Stop(streamErr)
return
default:
}
return true
})
if streamErr != nil {
+31 -4
View File
@@ -126,11 +126,11 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
// 检查是否为音频模型
isAudioModel := strings.Contains(strings.ToLower(model), "audio")
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
if lastStreamData != "" {
err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
if err != nil {
if err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent); err != nil {
common.SysLog("error handling stream format: " + err.Error())
sr.Error(err)
}
}
if len(data) > 0 {
@@ -142,7 +142,6 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
lastStreamData = data
streamItems = append(streamItems, data)
}
return true
})
// 对音频模型,从倒数第二个stream data中提取usage信息
@@ -627,6 +626,12 @@ func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, res
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
}
}
case constant.ChannelTypeOpenAI:
if usage.PromptTokensDetails.CachedTokens == 0 {
if cachedTokens, ok := extractLlamaCachedTokensFromBody(responseBody); ok {
usage.PromptTokensDetails.CachedTokens = cachedTokens
}
}
}
}
@@ -689,3 +694,25 @@ func extractMoonshotCachedTokensFromBody(body []byte) (int, bool) {
return 0, false
}
// extractLlamaCachedTokensFromBody 从llama.cpp的非标准位置提取cache_n
func extractLlamaCachedTokensFromBody(body []byte) (int, bool) {
if len(body) == 0 {
return 0, false
}
var payload struct {
Timings struct {
CachedTokens *int `json:"cache_n"`
} `json:"timings"`
}
if err := common.Unmarshal(body, &payload); err != nil {
return 0, false
}
if payload.Timings.CachedTokens == nil {
return 0, false
}
return *payload.Timings.CachedTokens, true
}
+38 -38
View File
@@ -79,55 +79,55 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
var usage = &dto.Usage{}
var responseTextBuilder strings.Builder
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
// 检查当前数据是否包含 completed 状态和 usage 信息
var streamResponse dto.ResponsesStreamResponse
if err := common.UnmarshalJsonStr(data, &streamResponse); err == nil {
sendResponsesStreamData(c, streamResponse, data)
switch streamResponse.Type {
case "response.completed":
if streamResponse.Response != nil {
if streamResponse.Response.Usage != nil {
if streamResponse.Response.Usage.InputTokens != 0 {
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
}
if streamResponse.Response.Usage.OutputTokens != 0 {
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
}
if streamResponse.Response.Usage.TotalTokens != 0 {
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
}
if streamResponse.Response.Usage.InputTokensDetails != nil {
usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
}
if err := common.UnmarshalJsonStr(data, &streamResponse); err != nil {
logger.LogError(c, "failed to unmarshal stream response: "+err.Error())
sr.Error(err)
return
}
sendResponsesStreamData(c, streamResponse, data)
switch streamResponse.Type {
case "response.completed":
if streamResponse.Response != nil {
if streamResponse.Response.Usage != nil {
if streamResponse.Response.Usage.InputTokens != 0 {
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
}
if streamResponse.Response.HasImageGenerationCall() {
c.Set("image_generation_call", true)
c.Set("image_generation_call_quality", streamResponse.Response.GetQuality())
c.Set("image_generation_call_size", streamResponse.Response.GetSize())
if streamResponse.Response.Usage.OutputTokens != 0 {
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
}
if streamResponse.Response.Usage.TotalTokens != 0 {
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
}
if streamResponse.Response.Usage.InputTokensDetails != nil {
usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
}
}
case "response.output_text.delta":
// 处理输出文本
responseTextBuilder.WriteString(streamResponse.Delta)
case dto.ResponsesOutputTypeItemDone:
// 函数调用处理
if streamResponse.Item != nil {
switch streamResponse.Item.Type {
case dto.BuildInCallWebSearchCall:
if info != nil && info.ResponsesUsageInfo != nil && info.ResponsesUsageInfo.BuiltInTools != nil {
if webSearchTool, exists := info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool != nil {
webSearchTool.CallCount++
}
if streamResponse.Response.HasImageGenerationCall() {
c.Set("image_generation_call", true)
c.Set("image_generation_call_quality", streamResponse.Response.GetQuality())
c.Set("image_generation_call_size", streamResponse.Response.GetSize())
}
}
case "response.output_text.delta":
// 处理输出文本
responseTextBuilder.WriteString(streamResponse.Delta)
case dto.ResponsesOutputTypeItemDone:
// 函数调用处理
if streamResponse.Item != nil {
switch streamResponse.Item.Type {
case dto.BuildInCallWebSearchCall:
if info != nil && info.ResponsesUsageInfo != nil && info.ResponsesUsageInfo.BuiltInTools != nil {
if webSearchTool, exists := info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool != nil {
webSearchTool.CallCount++
}
}
}
}
} else {
logger.LogError(c, "failed to unmarshal stream response: "+err.Error())
}
return true
})
if usage.CompletionTokens == 0 {
+2
View File
@@ -17,6 +17,8 @@ func UnmarshalMetadata(metadata map[string]any, target any) error {
if metadata == nil {
return nil
}
// Prevent metadata from overriding model fields to avoid billing bypass.
delete(metadata, "model")
metaBytes, err := common.Marshal(metadata)
if err != nil {
return fmt.Errorf("marshal metadata failed: %w", err)
+1 -1
View File
@@ -76,7 +76,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if strings.HasPrefix(request.Model, "grok-3-mini") {
if lo.FromPtrOr(request.MaxCompletionTokens, uint(0)) == 0 && lo.FromPtrOr(request.MaxTokens, uint(0)) != 0 {
request.MaxCompletionTokens = request.MaxTokens
request.MaxTokens = lo.ToPtr(uint(0))
request.MaxTokens = nil
}
if strings.HasSuffix(request.Model, "-high") {
request.ReasoningEffort = "high"
+6 -7
View File
@@ -43,12 +43,12 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
helper.SetEventStreamHeaders(c)
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
var xAIResp *dto.ChatCompletionsStreamResponse
err := common.UnmarshalJsonStr(data, &xAIResp)
if err != nil {
if err := common.UnmarshalJsonStr(data, &xAIResp); err != nil {
common.SysLog("error unmarshalling stream response: " + err.Error())
return true
sr.Error(err)
return
}
// 把 xAI 的usage转换为 OpenAI 的usage
@@ -61,11 +61,10 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
openaiResponse := streamResponseXAI2OpenAI(xAIResp, usage)
_ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
err = helper.ObjectData(c, openaiResponse)
if err != nil {
if err := helper.ObjectData(c, openaiResponse); err != nil {
common.SysLog(err.Error())
sr.Error(err)
}
return true
})
if !containStreamUsage {
+2 -2
View File
@@ -122,7 +122,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
return newApiErr
}
service.PostClaudeConsumeQuota(c, info, usage)
service.PostTextConsumeQuota(c, info, usage, nil)
return nil
}
@@ -190,6 +190,6 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
return newAPIError
}
service.PostClaudeConsumeQuota(c, info, usage.(*dto.Usage))
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
return nil
}
+231 -3
View File
@@ -21,10 +21,23 @@ var negativeIndexRegexp = regexp.MustCompile(`\.(-\d+)`)
const (
paramOverrideContextRequestHeaders = "request_headers"
paramOverrideContextHeaderOverride = "header_override"
paramOverrideContextAuditRecorder = "__param_override_audit_recorder"
)
var errSourceHeaderNotFound = errors.New("source header does not exist")
var paramOverrideKeyAuditPaths = map[string]struct{}{
"model": {},
"original_model": {},
"upstream_model": {},
"service_tier": {},
"inference_geo": {},
}
type paramOverrideAuditRecorder struct {
lines []string
}
type ConditionOperation struct {
Path string `json:"path"` // JSON路径
Mode string `json:"mode"` // full, prefix, suffix, contains, gt, gte, lt, lte
@@ -118,6 +131,7 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, c
if len(paramOverride) == 0 {
return jsonData, nil
}
auditRecorder := getParamOverrideAuditRecorder(conditionContext)
// 尝试断言为操作格式
if operations, ok := tryParseOperations(paramOverride); ok {
@@ -125,7 +139,7 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, c
workingJSON := jsonData
var err error
if len(legacyOverride) > 0 {
workingJSON, err = applyOperationsLegacy(workingJSON, legacyOverride)
workingJSON, err = applyOperationsLegacy(workingJSON, legacyOverride, auditRecorder)
if err != nil {
return nil, err
}
@@ -137,7 +151,7 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, c
}
// 直接使用旧方法
return applyOperationsLegacy(jsonData, paramOverride)
return applyOperationsLegacy(jsonData, paramOverride, auditRecorder)
}
func buildLegacyParamOverride(paramOverride map[string]interface{}) map[string]interface{} {
@@ -161,14 +175,200 @@ func ApplyParamOverrideWithRelayInfo(jsonData []byte, info *RelayInfo) ([]byte,
}
overrideCtx := BuildParamOverrideContext(info)
var recorder *paramOverrideAuditRecorder
if shouldEnableParamOverrideAudit(paramOverride) {
recorder = &paramOverrideAuditRecorder{}
overrideCtx[paramOverrideContextAuditRecorder] = recorder
}
result, err := ApplyParamOverride(jsonData, paramOverride, overrideCtx)
if err != nil {
return nil, err
}
syncRuntimeHeaderOverrideFromContext(info, overrideCtx)
if info != nil {
if recorder != nil {
info.ParamOverrideAudit = recorder.lines
} else {
info.ParamOverrideAudit = nil
}
}
return result, nil
}
func shouldEnableParamOverrideAudit(paramOverride map[string]interface{}) bool {
if common.DebugEnabled {
return true
}
if len(paramOverride) == 0 {
return false
}
if operations, ok := tryParseOperations(paramOverride); ok {
for _, operation := range operations {
if shouldAuditParamPath(strings.TrimSpace(operation.Path)) ||
shouldAuditParamPath(strings.TrimSpace(operation.To)) {
return true
}
}
for key := range buildLegacyParamOverride(paramOverride) {
if shouldAuditParamPath(strings.TrimSpace(key)) {
return true
}
}
return false
}
for key := range paramOverride {
if shouldAuditParamPath(strings.TrimSpace(key)) {
return true
}
}
return false
}
func getParamOverrideAuditRecorder(context map[string]interface{}) *paramOverrideAuditRecorder {
if context == nil {
return nil
}
recorder, _ := context[paramOverrideContextAuditRecorder].(*paramOverrideAuditRecorder)
return recorder
}
func (r *paramOverrideAuditRecorder) recordOperation(mode, path, from, to string, value interface{}) {
if r == nil {
return
}
line := buildParamOverrideAuditLine(mode, path, from, to, value)
if line == "" {
return
}
if lo.Contains(r.lines, line) {
return
}
r.lines = append(r.lines, line)
}
func shouldAuditParamPath(path string) bool {
path = strings.TrimSpace(path)
if path == "" {
return false
}
if common.DebugEnabled {
return true
}
_, ok := paramOverrideKeyAuditPaths[path]
return ok
}
func shouldAuditOperation(mode, path, from, to string) bool {
if common.DebugEnabled {
return true
}
for _, candidate := range []string{path, to} {
if shouldAuditParamPath(candidate) {
return true
}
}
return false
}
func formatParamOverrideAuditValue(value interface{}) string {
switch typed := value.(type) {
case nil:
return "<empty>"
case string:
return typed
default:
return common.GetJsonString(typed)
}
}
// buildParamOverrideAuditLine renders a single human-readable audit line for
// one override operation. It returns "" when the operation should not be
// audited (see shouldAuditOperation) or when the operation lacks the fields
// its mode requires (path, or from/to); callers skip empty lines.
func buildParamOverrideAuditLine(mode, path, from, to string, value interface{}) string {
	// Normalize all identifying fields before deciding anything.
	mode = strings.TrimSpace(mode)
	path = strings.TrimSpace(path)
	from = strings.TrimSpace(from)
	to = strings.TrimSpace(to)
	if !shouldAuditOperation(mode, path, from, to) {
		return ""
	}
	switch mode {
	// Single-path value mutations.
	case "set":
		if path == "" {
			return ""
		}
		return fmt.Sprintf("set %s = %s", path, formatParamOverrideAuditValue(value))
	case "delete":
		if path == "" {
			return ""
		}
		return fmt.Sprintf("delete %s", path)
	// Path-to-path transfers require both endpoints.
	case "copy":
		if from == "" || to == "" {
			return ""
		}
		return fmt.Sprintf("copy %s -> %s", from, to)
	case "move":
		if from == "" || to == "" {
			return ""
		}
		return fmt.Sprintf("move %s -> %s", from, to)
	// String edits that carry an argument value.
	case "prepend":
		if path == "" {
			return ""
		}
		return fmt.Sprintf("prepend %s with %s", path, formatParamOverrideAuditValue(value))
	case "append":
		if path == "" {
			return ""
		}
		return fmt.Sprintf("append %s with %s", path, formatParamOverrideAuditValue(value))
	case "trim_prefix", "trim_suffix", "ensure_prefix", "ensure_suffix":
		if path == "" {
			return ""
		}
		return fmt.Sprintf("%s %s with %s", mode, path, formatParamOverrideAuditValue(value))
	// Argument-less string transforms.
	case "trim_space", "to_lower", "to_upper":
		if path == "" {
			return ""
		}
		return fmt.Sprintf("%s %s", mode, path)
	case "replace", "regex_replace":
		if path == "" {
			return ""
		}
		return fmt.Sprintf("%s %s from %s to %s", mode, path, from, to)
	// Header-targeted operations mirror the body operations above.
	case "set_header":
		if path == "" {
			return ""
		}
		return fmt.Sprintf("set_header %s = %s", path, formatParamOverrideAuditValue(value))
	case "delete_header":
		if path == "" {
			return ""
		}
		return fmt.Sprintf("delete_header %s", path)
	case "copy_header", "move_header":
		if from == "" || to == "" {
			return ""
		}
		return fmt.Sprintf("%s %s -> %s", mode, from, to)
	case "pass_headers":
		return fmt.Sprintf("pass_headers %s", formatParamOverrideAuditValue(value))
	case "sync_fields":
		if from == "" || to == "" {
			return ""
		}
		return fmt.Sprintf("sync_fields %s -> %s", from, to)
	case "return_error":
		return fmt.Sprintf("return_error %s", formatParamOverrideAuditValue(value))
	default:
		// Unknown modes still produce a generic "<mode> <path>" line.
		if path == "" {
			return mode
		}
		return fmt.Sprintf("%s %s", mode, path)
	}
}
func getParamOverrideMap(info *RelayInfo) map[string]interface{} {
if info == nil || info.ChannelMeta == nil {
return nil
@@ -455,7 +655,7 @@ func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool,
}
// applyOperationsLegacy 原参数覆盖方法
func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}, auditRecorder *paramOverrideAuditRecorder) ([]byte, error) {
reqMap := make(map[string]interface{})
err := common.Unmarshal(jsonData, &reqMap)
if err != nil {
@@ -464,6 +664,7 @@ func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}
for key, value := range paramOverride {
reqMap[key] = value
auditRecorder.recordOperation("set", key, "", "", value)
}
return common.Marshal(reqMap)
@@ -471,6 +672,7 @@ func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}
func applyOperations(jsonStr string, operations []ParamOperation, conditionContext map[string]interface{}) (string, error) {
context := ensureContextMap(conditionContext)
auditRecorder := getParamOverrideAuditRecorder(context)
contextJSON, err := marshalContextJSON(context)
if err != nil {
return "", fmt.Errorf("failed to marshal condition context: %v", err)
@@ -506,6 +708,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
if err != nil {
break
}
auditRecorder.recordOperation("delete", path, "", "", nil)
}
case "set":
for _, path := range opPaths {
@@ -516,11 +719,15 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
if err != nil {
break
}
auditRecorder.recordOperation("set", path, "", "", op.Value)
}
case "move":
opFrom := processNegativeIndex(result, op.From)
opTo := processNegativeIndex(result, op.To)
result, err = moveValue(result, opFrom, opTo)
if err == nil {
auditRecorder.recordOperation("move", "", opFrom, opTo, nil)
}
case "copy":
if op.From == "" || op.To == "" {
return "", fmt.Errorf("copy from/to is required")
@@ -528,12 +735,16 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
opFrom := processNegativeIndex(result, op.From)
opTo := processNegativeIndex(result, op.To)
result, err = copyValue(result, opFrom, opTo)
if err == nil {
auditRecorder.recordOperation("copy", "", opFrom, opTo, nil)
}
case "prepend":
for _, path := range opPaths {
result, err = modifyValue(result, path, op.Value, op.KeepOrigin, true)
if err != nil {
break
}
auditRecorder.recordOperation("prepend", path, "", "", op.Value)
}
case "append":
for _, path := range opPaths {
@@ -541,6 +752,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
if err != nil {
break
}
auditRecorder.recordOperation("append", path, "", "", op.Value)
}
case "trim_prefix":
for _, path := range opPaths {
@@ -548,6 +760,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
if err != nil {
break
}
auditRecorder.recordOperation("trim_prefix", path, "", "", op.Value)
}
case "trim_suffix":
for _, path := range opPaths {
@@ -555,6 +768,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
if err != nil {
break
}
auditRecorder.recordOperation("trim_suffix", path, "", "", op.Value)
}
case "ensure_prefix":
for _, path := range opPaths {
@@ -562,6 +776,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
if err != nil {
break
}
auditRecorder.recordOperation("ensure_prefix", path, "", "", op.Value)
}
case "ensure_suffix":
for _, path := range opPaths {
@@ -569,6 +784,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
if err != nil {
break
}
auditRecorder.recordOperation("ensure_suffix", path, "", "", op.Value)
}
case "trim_space":
for _, path := range opPaths {
@@ -576,6 +792,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
if err != nil {
break
}
auditRecorder.recordOperation("trim_space", path, "", "", nil)
}
case "to_lower":
for _, path := range opPaths {
@@ -583,6 +800,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
if err != nil {
break
}
auditRecorder.recordOperation("to_lower", path, "", "", nil)
}
case "to_upper":
for _, path := range opPaths {
@@ -590,6 +808,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
if err != nil {
break
}
auditRecorder.recordOperation("to_upper", path, "", "", nil)
}
case "replace":
for _, path := range opPaths {
@@ -597,6 +816,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
if err != nil {
break
}
auditRecorder.recordOperation("replace", path, op.From, op.To, nil)
}
case "regex_replace":
for _, path := range opPaths {
@@ -604,8 +824,10 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
if err != nil {
break
}
auditRecorder.recordOperation("regex_replace", path, op.From, op.To, nil)
}
case "return_error":
auditRecorder.recordOperation("return_error", op.Path, "", "", op.Value)
returnErr, parseErr := parseParamOverrideReturnError(op.Value)
if parseErr != nil {
return "", parseErr
@@ -621,11 +843,13 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
case "set_header":
err = setHeaderOverrideInContext(context, op.Path, op.Value, op.KeepOrigin)
if err == nil {
auditRecorder.recordOperation("set_header", op.Path, "", "", op.Value)
contextJSON, err = marshalContextJSON(context)
}
case "delete_header":
err = deleteHeaderOverrideInContext(context, op.Path)
if err == nil {
auditRecorder.recordOperation("delete_header", op.Path, "", "", nil)
contextJSON, err = marshalContextJSON(context)
}
case "copy_header":
@@ -642,6 +866,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
err = nil
}
if err == nil {
auditRecorder.recordOperation("copy_header", "", sourceHeader, targetHeader, nil)
contextJSON, err = marshalContextJSON(context)
}
case "move_header":
@@ -658,6 +883,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
err = nil
}
if err == nil {
auditRecorder.recordOperation("move_header", "", sourceHeader, targetHeader, nil)
contextJSON, err = marshalContextJSON(context)
}
case "pass_headers":
@@ -675,11 +901,13 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
}
}
if err == nil {
auditRecorder.recordOperation("pass_headers", "", "", "", headerNames)
contextJSON, err = marshalContextJSON(context)
}
case "sync_fields":
result, err = syncFieldsBetweenTargets(result, context, op.From, op.To)
if err == nil {
auditRecorder.recordOperation("sync_fields", "", op.From, op.To, nil)
contextJSON, err = marshalContextJSON(context)
}
default:
+100
View File
@@ -6,6 +6,7 @@ import (
"reflect"
"testing"
common2 "github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/types"
"github.com/QuantumNous/new-api/dto"
@@ -2066,6 +2067,105 @@ func TestRemoveDisabledFieldsAllowInferenceGeo(t *testing.T) {
assertJSONEqual(t, `{"inference_geo":"eu","store":true}`, string(out))
}
// TestApplyParamOverrideWithRelayInfoRecordsOperationAuditInDebugMode checks
// that with the global debug flag on, every applied override operation is
// recorded in info.ParamOverrideAudit, in application order.
func TestApplyParamOverrideWithRelayInfoRecordsOperationAuditInDebugMode(t *testing.T) {
	// Flip the global debug flag on and restore it when the test ends so
	// other tests are unaffected.
	originalDebugEnabled := common2.DebugEnabled
	common2.DebugEnabled = true
	t.Cleanup(func() {
		common2.DebugEnabled = originalDebugEnabled
	})
	// Three operations: one copy plus two sets.
	info := &RelayInfo{
		ChannelMeta: &ChannelMeta{
			ParamOverride: map[string]interface{}{
				"operations": []interface{}{
					map[string]interface{}{
						"mode": "copy",
						"from": "metadata.target_model",
						"to":   "model",
					},
					map[string]interface{}{
						"mode":  "set",
						"path":  "service_tier",
						"value": "flex",
					},
					map[string]interface{}{
						"mode":  "set",
						"path":  "temperature",
						"value": 0.1,
					},
				},
			},
		},
	}
	out, err := ApplyParamOverrideWithRelayInfo([]byte(`{
"model":"gpt-4.1",
"temperature":0.7,
"metadata":{"target_model":"gpt-4.1-mini"}
}`), info)
	if err != nil {
		t.Fatalf("ApplyParamOverrideWithRelayInfo returned error: %v", err)
	}
	// The overridden body must reflect all three operations.
	assertJSONEqual(t, `{
"model":"gpt-4.1-mini",
"temperature":0.1,
"service_tier":"flex",
"metadata":{"target_model":"gpt-4.1-mini"}
}`, string(out))
	// Debug mode records every operation, not only registered key paths.
	expected := []string{
		"copy metadata.target_model -> model",
		"set service_tier = flex",
		"set temperature = 0.1",
	}
	if !reflect.DeepEqual(info.ParamOverrideAudit, expected) {
		t.Fatalf("unexpected param override audit, got %#v", info.ParamOverrideAudit)
	}
}
// TestApplyParamOverrideWithRelayInfoRecordsOnlyKeyOperationsWhenDebugDisabled
// checks that with debug off, only key operations are recorded — here the
// copy whose target is "model" (presumably registered in
// paramOverrideKeyAuditPaths), while the plain temperature set is not.
func TestApplyParamOverrideWithRelayInfoRecordsOnlyKeyOperationsWhenDebugDisabled(t *testing.T) {
	// Force debug off and restore the flag afterwards.
	originalDebugEnabled := common2.DebugEnabled
	common2.DebugEnabled = false
	t.Cleanup(func() {
		common2.DebugEnabled = originalDebugEnabled
	})
	info := &RelayInfo{
		ChannelMeta: &ChannelMeta{
			ParamOverride: map[string]interface{}{
				"operations": []interface{}{
					map[string]interface{}{
						"mode": "copy",
						"from": "metadata.target_model",
						"to":   "model",
					},
					map[string]interface{}{
						"mode":  "set",
						"path":  "temperature",
						"value": 0.1,
					},
				},
			},
		},
	}
	_, err := ApplyParamOverrideWithRelayInfo([]byte(`{
"model":"gpt-4.1",
"temperature":0.7,
"metadata":{"target_model":"gpt-4.1-mini"}
}`), info)
	if err != nil {
		t.Fatalf("ApplyParamOverrideWithRelayInfo returned error: %v", err)
	}
	// Only the copy that rewrote "model" qualifies for auditing.
	expected := []string{
		"copy metadata.target_model -> model",
	}
	if !reflect.DeepEqual(info.ParamOverrideAudit, expected) {
		t.Fatalf("unexpected param override audit, got %#v", info.ParamOverrideAudit)
	}
}
func assertJSONEqual(t *testing.T, want, got string) {
t.Helper()
+4 -6
View File
@@ -150,6 +150,7 @@ type RelayInfo struct {
LastError *types.NewAPIError
RuntimeHeadersOverride map[string]interface{}
UseRuntimeHeadersOverride bool
ParamOverrideAudit []string
PriceData types.PriceData
@@ -167,6 +168,8 @@ type RelayInfo struct {
// 若为空,调用 GetFinalRequestRelayFormat 会回退到 RequestConversionChain 的最后一项或 RelayFormat。
FinalRequestRelayFormat types.RelayFormat
StreamStatus *StreamStatus
ThinkingContentInfo
TokenCountMeta
*ClaudeConvertInfo
@@ -343,15 +346,10 @@ func GenRelayInfoClaude(c *gin.Context, request dto.Request) *RelayInfo {
info.ClaudeConvertInfo = &ClaudeConvertInfo{
LastMessagesType: LastMessageTypeNone,
}
info.IsClaudeBetaQuery = c.Query("beta") == "true" || isClaudeBetaForced(c)
info.IsClaudeBetaQuery = c.Query("beta") == "true"
return info
}
func isClaudeBetaForced(c *gin.Context) bool {
channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting)
return ok && channelOtherSettings.ClaudeBetaQuery
}
func GenRelayInfoRerank(c *gin.Context, request *dto.RerankRequest) *RelayInfo {
info := genBaseRelayInfo(c, request)
info.RelayMode = relayconstant.RelayModeRerank
+112
View File
@@ -0,0 +1,112 @@
package common
import (
"fmt"
"strings"
"sync"
"time"
)
// StreamEndReason classifies how a relay stream terminated.
type StreamEndReason string

// Possible stream termination reasons.
const (
	StreamEndReasonNone        StreamEndReason = ""
	StreamEndReasonDone        StreamEndReason = "done"
	StreamEndReasonTimeout     StreamEndReason = "timeout"
	StreamEndReasonClientGone  StreamEndReason = "client_gone"
	StreamEndReasonScannerErr  StreamEndReason = "scanner_error"
	StreamEndReasonHandlerStop StreamEndReason = "handler_stop"
	StreamEndReasonEOF         StreamEndReason = "eof"
	StreamEndReasonPanic       StreamEndReason = "panic"
	StreamEndReasonPingFail    StreamEndReason = "ping_fail"
)

// maxStreamErrorEntries caps how many soft-error entries are retained;
// ErrorCount keeps counting past the cap.
const maxStreamErrorEntries = 20

// StreamErrorEntry is one recorded soft error with its capture time.
type StreamErrorEntry struct {
	Message   string
	Timestamp time.Time
}

// StreamStatus accumulates the outcome of a streaming response: a single
// final end reason (first writer wins, via endOnce) plus a bounded list of
// soft errors guarded by mu. All methods are safe on a nil receiver.
type StreamStatus struct {
	EndReason StreamEndReason
	EndError  error
	endOnce   sync.Once

	mu         sync.Mutex
	Errors     []StreamErrorEntry
	ErrorCount int
}

// NewStreamStatus returns an empty status ready for use.
func NewStreamStatus() *StreamStatus {
	return new(StreamStatus)
}

// SetEndReason records why the stream ended. Only the first call has any
// effect; later calls (from racing goroutines or cleanup paths) are ignored.
func (s *StreamStatus) SetEndReason(reason StreamEndReason, err error) {
	if s == nil {
		return
	}
	s.endOnce.Do(func() {
		s.EndReason = reason
		s.EndError = err
	})
}

// RecordError registers a non-fatal ("soft") error. The total count always
// increments, but only the first maxStreamErrorEntries messages are stored.
func (s *StreamStatus) RecordError(msg string) {
	if s == nil {
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.ErrorCount++
	if len(s.Errors) >= maxStreamErrorEntries {
		return
	}
	s.Errors = append(s.Errors, StreamErrorEntry{Message: msg, Timestamp: time.Now()})
}

// HasErrors reports whether at least one soft error was recorded.
func (s *StreamStatus) HasErrors() bool {
	return s.TotalErrorCount() > 0
}

// TotalErrorCount returns the number of soft errors recorded, including
// those beyond the storage cap.
func (s *StreamStatus) TotalErrorCount() int {
	if s == nil {
		return 0
	}
	s.mu.Lock()
	count := s.ErrorCount
	s.mu.Unlock()
	return count
}

// IsNormalEnd reports whether the stream finished through an expected path
// (done, EOF, or an explicit handler stop). A nil status counts as normal.
func (s *StreamStatus) IsNormalEnd() bool {
	if s == nil {
		return true
	}
	switch s.EndReason {
	case StreamEndReasonDone, StreamEndReasonEOF, StreamEndReasonHandlerStop:
		return true
	default:
		return false
	}
}

// Summary renders a one-line description for logging, e.g.
// `reason=timeout end_error="..." soft_errors=3`.
// NOTE(review): EndReason/EndError are read without the mutex, mirroring the
// original design — callers are expected to invoke Summary after the stream
// has ended.
func (s *StreamStatus) Summary() string {
	if s == nil {
		return "StreamStatus<nil>"
	}
	var sb strings.Builder
	sb.WriteString("reason=")
	sb.WriteString(string(s.EndReason))
	if s.EndError != nil {
		fmt.Fprintf(&sb, " end_error=%q", s.EndError.Error())
	}
	s.mu.Lock()
	softErrors := s.ErrorCount
	s.mu.Unlock()
	if softErrors > 0 {
		fmt.Fprintf(&sb, " soft_errors=%d", softErrors)
	}
	return sb.String()
}
+182
View File
@@ -0,0 +1,182 @@
package common
import (
"fmt"
"sync"
"testing"
"github.com/stretchr/testify/assert"
)
// Verifies that only the first SetEndReason call takes effect (later calls,
// even with errors attached, are ignored).
func TestStreamStatus_SetEndReason_FirstWins(t *testing.T) {
	t.Parallel()
	s := NewStreamStatus()
	s.SetEndReason(StreamEndReasonDone, nil)
	s.SetEndReason(StreamEndReasonTimeout, nil)
	s.SetEndReason(StreamEndReasonClientGone, fmt.Errorf("context canceled"))
	assert.Equal(t, StreamEndReasonDone, s.EndReason)
	assert.Nil(t, s.EndError)
}

// Verifies that the error passed alongside the reason is retained.
func TestStreamStatus_SetEndReason_WithError(t *testing.T) {
	t.Parallel()
	s := NewStreamStatus()
	expectedErr := fmt.Errorf("read: connection reset")
	s.SetEndReason(StreamEndReasonScannerErr, expectedErr)
	assert.Equal(t, StreamEndReasonScannerErr, s.EndReason)
	assert.Equal(t, expectedErr, s.EndError)
}

// Verifies that calling SetEndReason on a nil *StreamStatus does not panic.
func TestStreamStatus_SetEndReason_NilSafe(t *testing.T) {
	t.Parallel()
	var s *StreamStatus
	s.SetEndReason(StreamEndReasonDone, nil)
}
// Verifies that under concurrent SetEndReason calls exactly one of the
// submitted reasons wins (the result is non-empty; which one is unspecified).
func TestStreamStatus_SetEndReason_Concurrent(t *testing.T) {
	t.Parallel()
	s := NewStreamStatus()
	reasons := []StreamEndReason{
		StreamEndReasonDone,
		StreamEndReasonTimeout,
		StreamEndReasonClientGone,
		StreamEndReasonScannerErr,
		StreamEndReasonHandlerStop,
		StreamEndReasonEOF,
		StreamEndReasonPanic,
		StreamEndReasonPingFail,
	}
	var wg sync.WaitGroup
	for _, r := range reasons {
		wg.Add(1)
		go func(reason StreamEndReason) {
			defer wg.Done()
			s.SetEndReason(reason, nil)
		}(r)
	}
	wg.Wait()
	// Some reason must have been recorded; first-wins makes the exact value
	// scheduling-dependent.
	assert.NotEqual(t, StreamEndReasonNone, s.EndReason)
}
// Verifies that each recorded error increments the count and is stored.
func TestStreamStatus_RecordError_Basic(t *testing.T) {
	t.Parallel()
	s := NewStreamStatus()
	s.RecordError("bad json")
	s.RecordError("another bad json")
	s.RecordError("client gone")
	assert.True(t, s.HasErrors())
	assert.Equal(t, 3, s.TotalErrorCount())
	assert.Len(t, s.Errors, 3)
}

// Verifies that stored entries are capped at maxStreamErrorEntries while the
// total count keeps growing past the cap.
func TestStreamStatus_RecordError_CapAtMax(t *testing.T) {
	t.Parallel()
	s := NewStreamStatus()
	for i := 0; i < 30; i++ {
		s.RecordError(fmt.Sprintf("error_%d", i))
	}
	assert.Equal(t, maxStreamErrorEntries, len(s.Errors))
	assert.Equal(t, 30, s.TotalErrorCount())
}

// Verifies that recording on a nil *StreamStatus does not panic.
func TestStreamStatus_RecordError_NilSafe(t *testing.T) {
	t.Parallel()
	var s *StreamStatus
	s.RecordError("should not panic")
}
// Verifies that concurrent RecordError calls are counted without loss and
// the stored-entry cap is still respected (exercised under -race).
func TestStreamStatus_RecordError_Concurrent(t *testing.T) {
	t.Parallel()
	s := NewStreamStatus()
	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()
			s.RecordError(fmt.Sprintf("error_%d", idx))
		}(i)
	}
	wg.Wait()
	assert.Equal(t, 100, s.TotalErrorCount())
	assert.LessOrEqual(t, len(s.Errors), maxStreamErrorEntries)
}

// Verifies that a fresh status reports no errors.
func TestStreamStatus_HasErrors_Empty(t *testing.T) {
	t.Parallel()
	s := NewStreamStatus()
	assert.False(t, s.HasErrors())
	assert.Equal(t, 0, s.TotalErrorCount())
}

// Verifies that error queries on a nil *StreamStatus return zero values.
func TestStreamStatus_HasErrors_NilSafe(t *testing.T) {
	t.Parallel()
	var s *StreamStatus
	assert.False(t, s.HasErrors())
	assert.Equal(t, 0, s.TotalErrorCount())
}
// Table test: only done/eof/handler_stop count as a normal stream end.
func TestStreamStatus_IsNormalEnd(t *testing.T) {
	t.Parallel()
	tests := []struct {
		reason StreamEndReason
		normal bool
	}{
		{StreamEndReasonDone, true},
		{StreamEndReasonEOF, true},
		{StreamEndReasonHandlerStop, true},
		{StreamEndReasonTimeout, false},
		{StreamEndReasonClientGone, false},
		{StreamEndReasonScannerErr, false},
		{StreamEndReasonPanic, false},
		{StreamEndReasonPingFail, false},
		{StreamEndReasonNone, false},
	}
	for _, tt := range tests {
		s := NewStreamStatus()
		s.SetEndReason(tt.reason, nil)
		assert.Equal(t, tt.normal, s.IsNormalEnd(), "reason=%s", tt.reason)
	}
}

// Verifies that a nil *StreamStatus is treated as a normal end.
func TestStreamStatus_IsNormalEnd_NilSafe(t *testing.T) {
	t.Parallel()
	var s *StreamStatus
	assert.True(t, s.IsNormalEnd())
}
// Verifies the Summary log line: always includes the reason, and appends a
// soft_errors counter only when errors were recorded.
func TestStreamStatus_Summary(t *testing.T) {
	t.Parallel()
	s := NewStreamStatus()
	s.SetEndReason(StreamEndReasonDone, nil)
	summary := s.Summary()
	assert.Contains(t, summary, "reason=done")
	assert.NotContains(t, summary, "soft_errors")
	s2 := NewStreamStatus()
	s2.SetEndReason(StreamEndReasonTimeout, nil)
	s2.RecordError("bad json")
	s2.RecordError("write failed")
	summary2 := s2.Summary()
	assert.Contains(t, summary2, "reason=timeout")
	assert.Contains(t, summary2, "soft_errors=2")
}

// Verifies the fixed placeholder summary for a nil *StreamStatus.
func TestStreamStatus_Summary_NilSafe(t *testing.T) {
	t.Parallel()
	var s *StreamStatus
	assert.Equal(t, "StreamStatus<nil>", s.Summary())
}
+2 -265
View File
@@ -6,26 +6,20 @@ import (
"io"
"net/http"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/pkg/billingexpr"
relaycommon "github.com/QuantumNous/new-api/relay/common"
relayconstant "github.com/QuantumNous/new-api/relay/constant"
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/model_setting"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/QuantumNous/new-api/types"
"github.com/samber/lo"
"github.com/shopspring/decimal"
"github.com/gin-gonic/gin"
)
@@ -94,7 +88,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
if containAudioTokens && containsAudioRatios {
service.PostAudioConsumeQuota(c, info, usage, "")
} else {
postConsumeQuota(c, info, usage)
service.PostTextConsumeQuota(c, info, usage, nil)
}
return nil
}
@@ -217,264 +211,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
if containAudioTokens && containsAudioRatios {
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
} else {
postConsumeQuota(c, info, usage.(*dto.Usage))
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
}
return nil
}
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent ...string) {
originUsage := usage
if usage == nil {
usage = &dto.Usage{
PromptTokens: relayInfo.GetEstimatePromptTokens(),
CompletionTokens: 0,
TotalTokens: relayInfo.GetEstimatePromptTokens(),
}
extraContent = append(extraContent, "上游无计费信息")
}
if originUsage != nil {
service.ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat())
}
// Tiered billing: only determines quota, logging continues through normal path
isClaudeUsageSemantic := relayInfo.GetFinalRequestRelayFormat() == types.RelayFormatClaude
var tieredUsedVars map[string]bool
if snap := relayInfo.TieredBillingSnapshot; snap != nil {
tieredUsedVars = billingexpr.UsedVars(snap.ExprString)
}
var tieredResult *billingexpr.TieredResult
tieredOk, tieredQuota, tieredRes := service.TryTieredSettle(relayInfo, service.BuildTieredTokenParams(usage, isClaudeUsageSemantic, tieredUsedVars))
if tieredOk {
tieredResult = tieredRes
}
adminRejectReason := common.GetContextKeyString(ctx, constant.ContextKeyAdminRejectReason)
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens
cacheTokens := usage.PromptTokensDetails.CachedTokens
imageTokens := usage.PromptTokensDetails.ImageTokens
audioTokens := usage.PromptTokensDetails.AudioTokens
completionTokens := usage.CompletionTokens
cachedCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
modelName := relayInfo.OriginModelName
tokenName := ctx.GetString("token_name")
completionRatio := relayInfo.PriceData.CompletionRatio
cacheRatio := relayInfo.PriceData.CacheRatio
imageRatio := relayInfo.PriceData.ImageRatio
modelRatio := relayInfo.PriceData.ModelRatio
groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
modelPrice := relayInfo.PriceData.ModelPrice
cachedCreationRatio := relayInfo.PriceData.CacheCreationRatio
// Convert values to decimal for precise calculation
dPromptTokens := decimal.NewFromInt(int64(promptTokens))
dCacheTokens := decimal.NewFromInt(int64(cacheTokens))
dImageTokens := decimal.NewFromInt(int64(imageTokens))
dAudioTokens := decimal.NewFromInt(int64(audioTokens))
dCompletionTokens := decimal.NewFromInt(int64(completionTokens))
dCachedCreationTokens := decimal.NewFromInt(int64(cachedCreationTokens))
dCompletionRatio := decimal.NewFromFloat(completionRatio)
dCacheRatio := decimal.NewFromFloat(cacheRatio)
dImageRatio := decimal.NewFromFloat(imageRatio)
dModelRatio := decimal.NewFromFloat(modelRatio)
dGroupRatio := decimal.NewFromFloat(groupRatio)
dModelPrice := decimal.NewFromFloat(modelPrice)
dCachedCreationRatio := decimal.NewFromFloat(cachedCreationRatio)
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
ratio := dModelRatio.Mul(dGroupRatio)
// Collect tool call usage from context and relayInfo
toolUsage := service.ToolCallUsage{
ModelName: modelName,
ImageGenerationCall: ctx.GetBool("image_generation_call"),
ImageGenerationQuality: ctx.GetString("image_generation_call_quality"),
ImageGenerationSize: ctx.GetString("image_generation_call_size"),
}
if relayInfo.ResponsesUsageInfo != nil {
if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 {
toolUsage.WebSearchCalls = webSearchTool.CallCount
toolUsage.WebSearchToolName = dto.BuildInToolWebSearchPreview
}
if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists {
toolUsage.FileSearchCalls = fileSearchTool.CallCount
}
} else if strings.HasSuffix(modelName, "search-preview") {
toolUsage.WebSearchCalls = 1
toolUsage.WebSearchToolName = dto.BuildInToolWebSearchPreview
}
if claudeSearchCalls := ctx.GetInt("claude_web_search_requests"); claudeSearchCalls > 0 {
toolUsage.WebSearchCalls = claudeSearchCalls
toolUsage.WebSearchToolName = "web_search"
}
toolResult := service.ComputeToolCallQuota(toolUsage, groupRatio)
for _, item := range toolResult.Items {
extraContent = append(extraContent, fmt.Sprintf("%s 调用 %d 次,花费 %d", item.Name, item.CallCount, item.Quota))
}
var quotaCalculateDecimal decimal.Decimal
var audioInputQuota decimal.Decimal
var audioInputPrice float64
if !relayInfo.PriceData.UsePrice {
baseTokens := dPromptTokens
// 减去 cached tokens
// Anthropic API 的 input_tokens 已经不包含缓存 tokens,不需要减去
// OpenAI/OpenRouter 等 API 的 prompt_tokens 包含缓存 tokens,需要减去
var cachedTokensWithRatio decimal.Decimal
if !dCacheTokens.IsZero() {
if !isClaudeUsageSemantic {
baseTokens = baseTokens.Sub(dCacheTokens)
}
cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
}
var dCachedCreationTokensWithRatio decimal.Decimal
if !dCachedCreationTokens.IsZero() {
if !isClaudeUsageSemantic {
baseTokens = baseTokens.Sub(dCachedCreationTokens)
}
dCachedCreationTokensWithRatio = dCachedCreationTokens.Mul(dCachedCreationRatio)
}
// 减去 image tokens
var imageTokensWithRatio decimal.Decimal
if !dImageTokens.IsZero() {
baseTokens = baseTokens.Sub(dImageTokens)
imageTokensWithRatio = dImageTokens.Mul(dImageRatio)
}
// 减去 Gemini audio tokens
if !dAudioTokens.IsZero() {
audioInputPrice = operation_setting.GetGeminiInputAudioPricePerMillionTokens(modelName)
if audioInputPrice > 0 {
// 重新计算 base tokens
baseTokens = baseTokens.Sub(dAudioTokens)
audioInputQuota = decimal.NewFromFloat(audioInputPrice).Div(decimal.NewFromInt(1000000)).Mul(dAudioTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit)
extraContent = append(extraContent, fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String()))
}
}
promptQuota := baseTokens.Add(cachedTokensWithRatio).
Add(imageTokensWithRatio).
Add(dCachedCreationTokensWithRatio)
completionQuota := dCompletionTokens.Mul(dCompletionRatio)
quotaCalculateDecimal = promptQuota.Add(completionQuota).Mul(ratio)
if !ratio.IsZero() && quotaCalculateDecimal.LessThanOrEqual(decimal.Zero) {
quotaCalculateDecimal = decimal.NewFromInt(1)
}
} else {
quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
}
// 添加 audio input 独立计费(Gemini 音频按 token 计价,不属于工具调用)
quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
if len(relayInfo.PriceData.OtherRatios) > 0 {
for key, otherRatio := range relayInfo.PriceData.OtherRatios {
dOtherRatio := decimal.NewFromFloat(otherRatio)
quotaCalculateDecimal = quotaCalculateDecimal.Mul(dOtherRatio)
extraContent = append(extraContent, fmt.Sprintf("其他倍率 %s: %f", key, otherRatio))
}
}
quota := int(quotaCalculateDecimal.Round(0).IntPart())
if tieredOk {
quota = tieredQuota
}
// Tool call fees: add for per-token and tiered billing; skip for per-call (price includes everything)
if !relayInfo.PriceData.UsePrice && toolResult.TotalQuota > 0 {
quota += toolResult.TotalQuota
}
totalTokens := promptTokens + completionTokens
// record all the consume log even if quota is 0
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
extraContent = append(extraContent, "上游没有返回计费信息,无法扣费(可能是上游超时)")
logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
} else {
if !ratio.IsZero() && quota == 0 {
quota = 1
}
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
if err := service.SettleBilling(ctx, relayInfo, quota); err != nil {
logger.LogError(ctx, "error settling billing: "+err.Error())
}
logModel := modelName
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
logModel = "gpt-4-gizmo-*"
extraContent = append(extraContent, fmt.Sprintf("模型 %s", modelName))
}
if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
logModel = "gpt-4o-gizmo-*"
extraContent = append(extraContent, fmt.Sprintf("模型 %s", modelName))
}
logContent := strings.Join(extraContent, ", ")
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
if adminRejectReason != "" {
other["reject_reason"] = adminRejectReason
}
// For chat-based calls to the Claude model, tagging is required. Using Claude's rendering logs, the two approaches handle input rendering differently.
if isClaudeUsageSemantic {
other["claude"] = true
other["usage_semantic"] = "anthropic"
}
if imageTokens != 0 {
other["image"] = true
other["image_ratio"] = imageRatio
other["image_output"] = imageTokens
}
if cachedCreationTokens != 0 {
other["cache_creation_tokens"] = cachedCreationTokens
other["cache_creation_ratio"] = cachedCreationRatio
}
for _, item := range toolResult.Items {
switch item.Name {
case "web_search", "claude_web_search":
other["web_search"] = true
other["web_search_call_count"] = item.CallCount
other["web_search_price"] = item.PricePer1K
case "file_search":
other["file_search"] = true
other["file_search_call_count"] = item.CallCount
other["file_search_price"] = item.PricePer1K
case "image_generation":
other["image_generation_call"] = true
other["image_generation_call_price"] = item.TotalPrice
}
}
if !audioInputQuota.IsZero() {
other["audio_input_seperate_price"] = true
other["audio_input_token_count"] = audioTokens
other["audio_input_price"] = audioInputPrice
}
if tieredResult != nil {
service.InjectTieredBillingInfo(other, relayInfo, tieredResult)
}
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: relayInfo.ChannelId,
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
ModelName: logModel,
TokenName: tokenName,
Quota: quota,
Content: logContent,
TokenId: relayInfo.TokenId,
UseTimeSeconds: int(useTimeSeconds),
IsStream: relayInfo.IsStream,
Group: relayInfo.UsingGroup,
Other: other,
})
}
+1 -1
View File
@@ -83,6 +83,6 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
postConsumeQuota(c, info, usage.(*dto.Usage))
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
return nil
}
+2 -2
View File
@@ -194,7 +194,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
return openaiErr
}
postConsumeQuota(c, info, usage.(*dto.Usage))
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
return nil
}
@@ -288,6 +288,6 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
return openaiErr
}
postConsumeQuota(c, info, usage.(*dto.Usage))
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
return nil
}
+52
View File
@@ -0,0 +1,52 @@
package helper
import (
relaycommon "github.com/QuantumNous/new-api/relay/common"
)
// StreamResult is passed to each dataHandler invocation, providing methods
// to record soft errors, signal fatal stops, or mark normal completion.
// StreamScannerHandler checks IsStopped() after each callback invocation.
type StreamResult struct {
status *relaycommon.StreamStatus
stopped bool
}
func newStreamResult(status *relaycommon.StreamStatus) *StreamResult {
return &StreamResult{status: status}
}
// Error records a soft error. The stream continues processing.
// Can be called multiple times per chunk.
func (r *StreamResult) Error(err error) {
if err == nil {
return
}
r.status.RecordError(err.Error())
}
// Stop records a fatal error and marks the stream to stop after this chunk.
func (r *StreamResult) Stop(err error) {
if err != nil {
r.status.RecordError(err.Error())
}
r.status.SetEndReason(relaycommon.StreamEndReasonHandlerStop, err)
r.stopped = true
}
// Done signals that the handler has finished processing normally
// (e.g., Dify "message_end"): the end reason becomes StreamEndReasonDone
// and the stream stops after the current chunk.
func (r *StreamResult) Done() {
	r.stopped = true
	r.status.SetEndReason(relaycommon.StreamEndReasonDone, nil)
}
// IsStopped returns whether Stop() or Done() was called during this chunk.
// StreamScannerHandler consults it after each dataHandler invocation to
// decide whether to keep draining the data channel.
func (r *StreamResult) IsStopped() bool {
	return r.stopped
}
// reset clears the per-chunk stopped flag so the object can be reused.
// Called once per chunk before handing the StreamResult to dataHandler;
// recorded errors and the end reason (held in status) are NOT cleared.
func (r *StreamResult) reset() {
	r.stopped = false
}
+26 -10
View File
@@ -34,12 +34,15 @@ func getScannerBufferSize() int {
return DefaultMaxScannerBufferSize
}
func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) {
func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string, sr *StreamResult)) {
if resp == nil || dataHandler == nil {
return
}
// 无条件新建 StreamStatus
info.StreamStatus = relaycommon.NewStreamStatus()
// 确保响应体总是被关闭
defer func() {
if resp.Body != nil {
@@ -121,6 +124,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
wg.Done()
if r := recover(); r != nil {
logger.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r))
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPanic, fmt.Errorf("ping panic: %v", r))
common.SafeSendBool(stopChan, true)
}
if common.DebugEnabled {
@@ -148,6 +152,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
case err := <-done:
if err != nil {
logger.LogError(c, "ping data error: "+err.Error())
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPingFail, err)
return
}
if common.DebugEnabled {
@@ -155,6 +160,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
}
case <-time.After(10 * time.Second):
logger.LogError(c, "ping data send timeout")
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPingFail, fmt.Errorf("ping send timeout"))
return
case <-ctx.Done():
return
@@ -184,14 +190,17 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
wg.Done()
if r := recover(); r != nil {
logger.LogError(c, fmt.Sprintf("data handler goroutine panic: %v", r))
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPanic, fmt.Errorf("handler panic: %v", r))
}
common.SafeSendBool(stopChan, true)
}()
sr := newStreamResult(info.StreamStatus)
for data := range dataChan {
sr.reset()
writeMutex.Lock()
success := dataHandler(data)
dataHandler(data, sr)
writeMutex.Unlock()
if !success {
if sr.IsStopped() {
return
}
}
@@ -205,6 +214,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
wg.Done()
if r := recover(); r != nil {
logger.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r))
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPanic, fmt.Errorf("scanner panic: %v", r))
}
common.SafeSendBool(stopChan, true)
if common.DebugEnabled {
@@ -220,6 +230,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
case <-ctx.Done():
return
case <-c.Request.Context().Done():
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, c.Request.Context().Err())
return
default:
}
@@ -253,7 +264,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
return
}
} else {
// done, 处理完成标志,直接退出停止读取剩余数据防止出错
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonDone, nil)
if common.DebugEnabled {
println("received [DONE], stopping scanner")
}
@@ -264,20 +275,25 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
if err := scanner.Err(); err != nil {
if err != io.EOF {
logger.LogError(c, "scanner error: "+err.Error())
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonScannerErr, err)
}
}
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonEOF, nil)
})
// 主循环等待完成或超时
select {
case <-ticker.C:
// 超时处理逻辑
logger.LogError(c, "streaming timeout")
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonTimeout, nil)
case <-stopChan:
// 正常结束
logger.LogInfo(c, "streaming finished")
// EndReason already set by the goroutine that triggered stopChan
case <-c.Request.Context().Done():
// 客户端断开连接
logger.LogInfo(c, "client disconnected")
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, c.Request.Context().Err())
}
if info.StreamStatus.IsNormalEnd() && !info.StreamStatus.HasErrors() {
logger.LogInfo(c, fmt.Sprintf("stream ended: %s", info.StreamStatus.Summary()))
} else {
logger.LogError(c, fmt.Sprintf("stream ended: %s, received=%d", info.StreamStatus.Summary(), info.ReceivedResponseCount))
}
}
+216 -47
View File
@@ -56,8 +56,6 @@ func buildSSEBody(n int) string {
return b.String()
}
// slowReader wraps a reader and injects a delay before each Read call,
// simulating a slow upstream that trickles data.
type slowReader struct {
r io.Reader
delay time.Duration
@@ -79,7 +77,7 @@ func TestStreamScannerHandler_NilInputs(t *testing.T) {
info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
StreamScannerHandler(c, nil, info, func(data string) bool { return true })
StreamScannerHandler(c, nil, info, func(data string, sr *StreamResult) {})
StreamScannerHandler(c, &http.Response{Body: io.NopCloser(strings.NewReader(""))}, info, nil)
}
@@ -89,9 +87,8 @@ func TestStreamScannerHandler_EmptyBody(t *testing.T) {
c, resp, info := setupStreamTest(t, strings.NewReader(""))
var called atomic.Bool
StreamScannerHandler(c, resp, info, func(data string) bool {
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
called.Store(true)
return true
})
assert.False(t, called.Load(), "handler should not be called for empty body")
@@ -105,9 +102,8 @@ func TestStreamScannerHandler_1000Chunks(t *testing.T) {
c, resp, info := setupStreamTest(t, strings.NewReader(body))
var count atomic.Int64
StreamScannerHandler(c, resp, info, func(data string) bool {
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
count.Add(1)
return true
})
assert.Equal(t, int64(numChunks), count.Load())
@@ -124,9 +120,8 @@ func TestStreamScannerHandler_10000Chunks(t *testing.T) {
var count atomic.Int64
start := time.Now()
StreamScannerHandler(c, resp, info, func(data string) bool {
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
count.Add(1)
return true
})
elapsed := time.Since(start)
@@ -145,11 +140,10 @@ func TestStreamScannerHandler_OrderPreserved(t *testing.T) {
var mu sync.Mutex
received := make([]string, 0, numChunks)
StreamScannerHandler(c, resp, info, func(data string) bool {
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
mu.Lock()
received = append(received, data)
mu.Unlock()
return true
})
require.Equal(t, numChunks, len(received))
@@ -166,31 +160,32 @@ func TestStreamScannerHandler_DoneStopsScanner(t *testing.T) {
c, resp, info := setupStreamTest(t, strings.NewReader(body))
var count atomic.Int64
StreamScannerHandler(c, resp, info, func(data string) bool {
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
count.Add(1)
return true
})
assert.Equal(t, int64(50), count.Load(), "data after [DONE] must not be processed")
}
func TestStreamScannerHandler_HandlerFailureStops(t *testing.T) {
func TestStreamScannerHandler_StopStopsStream(t *testing.T) {
t.Parallel()
const numChunks = 200
body := buildSSEBody(numChunks)
c, resp, info := setupStreamTest(t, strings.NewReader(body))
const failAt = 50
const stopAt int64 = 50
var count atomic.Int64
StreamScannerHandler(c, resp, info, func(data string) bool {
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
n := count.Add(1)
return n < failAt
if n >= stopAt {
sr.Stop(fmt.Errorf("fatal at %d", n))
}
})
// The worker stops at failAt; the scanner may have read ahead,
// but the handler should not be called beyond failAt.
assert.Equal(t, int64(failAt), count.Load())
assert.Equal(t, stopAt, count.Load())
require.NotNil(t, info.StreamStatus)
assert.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason)
}
func TestStreamScannerHandler_SkipsNonDataLines(t *testing.T) {
@@ -210,9 +205,8 @@ func TestStreamScannerHandler_SkipsNonDataLines(t *testing.T) {
c, resp, info := setupStreamTest(t, strings.NewReader(b.String()))
var count atomic.Int64
StreamScannerHandler(c, resp, info, func(data string) bool {
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
count.Add(1)
return true
})
assert.Equal(t, int64(100), count.Load())
@@ -225,25 +219,18 @@ func TestStreamScannerHandler_DataWithExtraSpaces(t *testing.T) {
c, resp, info := setupStreamTest(t, strings.NewReader(body))
var got string
StreamScannerHandler(c, resp, info, func(data string) bool {
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
got = data
return true
})
assert.Equal(t, "{\"trimmed\":true}", got)
}
// ---------- Decoupling: scanner not blocked by slow handler ----------
// ---------- Decoupling ----------
func TestStreamScannerHandler_ScannerDecoupledFromSlowHandler(t *testing.T) {
t.Parallel()
// Strategy: use a slow upstream (io.Pipe, 10ms per chunk) AND a slow handler (20ms per chunk).
// If the scanner were synchronously coupled to the handler, total time would be
// ~numChunks * (10ms + 20ms) = 30ms * 50 = 1500ms.
// With decoupling, total time should be closer to
// ~numChunks * max(10ms, 20ms) = 20ms * 50 = 1000ms
// because the scanner reads ahead into the buffer while the handler processes.
const numChunks = 50
const upstreamDelay = 10 * time.Millisecond
const handlerDelay = 20 * time.Millisecond
@@ -273,10 +260,9 @@ func TestStreamScannerHandler_ScannerDecoupledFromSlowHandler(t *testing.T) {
start := time.Now()
done := make(chan struct{})
go func() {
StreamScannerHandler(c, resp, info, func(data string) bool {
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
time.Sleep(handlerDelay)
count.Add(1)
return true
})
close(done)
}()
@@ -293,7 +279,6 @@ func TestStreamScannerHandler_ScannerDecoupledFromSlowHandler(t *testing.T) {
coupledTime := time.Duration(numChunks) * (upstreamDelay + handlerDelay)
t.Logf("elapsed=%v, coupled_estimate=%v", elapsed, coupledTime)
// If decoupled, elapsed should be well under the coupled estimate.
assert.Less(t, elapsed, coupledTime*85/100,
"decoupled elapsed time (%v) should be significantly less than coupled estimate (%v)", elapsed, coupledTime)
}
@@ -311,9 +296,8 @@ func TestStreamScannerHandler_SlowUpstreamFastHandler(t *testing.T) {
done := make(chan struct{})
go func() {
StreamScannerHandler(c, resp, info, func(data string) bool {
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
count.Add(1)
return true
})
close(done)
}()
@@ -344,8 +328,6 @@ func TestStreamScannerHandler_PingSentDuringSlowUpstream(t *testing.T) {
setting.PingIntervalSeconds = oldSeconds
})
// Create a reader that delivers data slowly: one chunk every 500ms over 3.5 seconds.
// The ping interval is 1s, so we should see at least 2 pings.
pr, pw := io.Pipe()
go func() {
defer pw.Close()
@@ -372,9 +354,8 @@ func TestStreamScannerHandler_PingSentDuringSlowUpstream(t *testing.T) {
var count atomic.Int64
done := make(chan struct{})
go func() {
StreamScannerHandler(c, resp, info, func(data string) bool {
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
count.Add(1)
return true
})
close(done)
}()
@@ -436,9 +417,8 @@ func TestStreamScannerHandler_PingDisabledByRelayInfo(t *testing.T) {
var count atomic.Int64
done := make(chan struct{})
go func() {
StreamScannerHandler(c, resp, info, func(data string) bool {
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
count.Add(1)
return true
})
close(done)
}()
@@ -456,6 +436,199 @@ func TestStreamScannerHandler_PingDisabledByRelayInfo(t *testing.T) {
assert.Equal(t, 0, pingCount, "pings should be disabled when DisablePing=true")
}
// ---------- StreamStatus integration ----------
// A body terminated by [DONE] must finish with StreamEndReasonDone, a nil
// end error, a "normal end" classification, and no recorded soft errors.
func TestStreamScannerHandler_StreamStatus_DoneReason(t *testing.T) {
	t.Parallel()

	ctx, upstream, info := setupStreamTest(t, strings.NewReader(buildSSEBody(10)))
	StreamScannerHandler(ctx, upstream, info, func(data string, sr *StreamResult) {})

	status := info.StreamStatus
	require.NotNil(t, status)
	assert.Equal(t, relaycommon.StreamEndReasonDone, status.EndReason)
	assert.Nil(t, status.EndError)
	assert.True(t, status.IsNormalEnd())
	assert.False(t, status.HasErrors())
}
// A body that simply ends without a [DONE] sentinel must be classified as
// StreamEndReasonEOF, which still counts as a normal end.
func TestStreamScannerHandler_StreamStatus_EOFWithoutDone(t *testing.T) {
	t.Parallel()

	var lines []string
	for i := 0; i < 5; i++ {
		lines = append(lines, fmt.Sprintf("data: {\"id\":%d}\n", i))
	}
	ctx, upstream, info := setupStreamTest(t, strings.NewReader(strings.Join(lines, "")))
	StreamScannerHandler(ctx, upstream, info, func(data string, sr *StreamResult) {})

	require.NotNil(t, info.StreamStatus)
	assert.Equal(t, relaycommon.StreamEndReasonEOF, info.StreamStatus.EndReason)
	assert.True(t, info.StreamStatus.IsNormalEnd())
}
// When the handler calls Stop() with an error, the end reason must be
// StreamEndReasonHandlerStop and the error must be counted.
func TestStreamScannerHandler_StreamStatus_HandlerStop(t *testing.T) {
	t.Parallel()

	ctx, upstream, info := setupStreamTest(t, strings.NewReader(buildSSEBody(100)))

	var seen atomic.Int64
	StreamScannerHandler(ctx, upstream, info, func(data string, sr *StreamResult) {
		if seen.Add(1) >= 10 {
			sr.Stop(fmt.Errorf("stop at 10"))
		}
	})

	require.NotNil(t, info.StreamStatus)
	assert.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason)
	assert.True(t, info.StreamStatus.HasErrors())
}
// When the handler calls Done() mid-stream, processing must stop at that
// chunk, the end reason must be StreamEndReasonDone, and no errors recorded.
func TestStreamScannerHandler_StreamStatus_HandlerDone(t *testing.T) {
	t.Parallel()

	ctx, upstream, info := setupStreamTest(t, strings.NewReader(buildSSEBody(20)))

	var seen atomic.Int64
	StreamScannerHandler(ctx, upstream, info, func(data string, sr *StreamResult) {
		if seen.Add(1) >= 5 {
			sr.Done()
		}
	})

	assert.Equal(t, int64(5), seen.Load())
	require.NotNil(t, info.StreamStatus)
	assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
	assert.False(t, info.StreamStatus.HasErrors())
}
// TestStreamScannerHandler_StreamStatus_Timeout verifies that a stalled
// upstream trips the streaming timeout and is classified as
// StreamEndReasonTimeout (an abnormal end).
func TestStreamScannerHandler_StreamStatus_Timeout(t *testing.T) {
	// Not parallel: modifies global constant.StreamingTimeout
	oldTimeout := constant.StreamingTimeout
	constant.StreamingTimeout = 2
	t.Cleanup(func() { constant.StreamingTimeout = oldTimeout })
	// Upstream delivers one chunk, then hangs well past the 2s timeout.
	pr, pw := io.Pipe()
	go func() {
		fmt.Fprint(pw, "data: {\"id\":1}\n")
		time.Sleep(10 * time.Second)
		pw.Close()
	}()
	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
	resp := &http.Response{Body: pr}
	info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
	// Run the handler off the test goroutine so we can bound the wait below.
	done := make(chan struct{})
	go func() {
		StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {})
		close(done)
	}()
	// 15s cap: generous versus the 2s streaming timeout, so a hang here
	// indicates the timeout path never fired.
	select {
	case <-done:
	case <-time.After(15 * time.Second):
		t.Fatal("timed out waiting for stream timeout")
	}
	require.NotNil(t, info.StreamStatus)
	assert.Equal(t, relaycommon.StreamEndReasonTimeout, info.StreamStatus.EndReason)
	assert.False(t, info.StreamStatus.IsNormalEnd())
}
// Soft errors recorded via sr.Error must not stop the stream: all 10 chunks
// are processed, the end reason is still Done, and every error is counted.
func TestStreamScannerHandler_StreamStatus_SoftErrors(t *testing.T) {
	t.Parallel()

	ctx, upstream, info := setupStreamTest(t, strings.NewReader(buildSSEBody(10)))
	StreamScannerHandler(ctx, upstream, info, func(data string, sr *StreamResult) {
		sr.Error(fmt.Errorf("soft error for chunk"))
	})

	status := info.StreamStatus
	require.NotNil(t, status)
	assert.Equal(t, relaycommon.StreamEndReasonDone, status.EndReason)
	assert.True(t, status.HasErrors())
	assert.Equal(t, 10, status.TotalErrorCount())
}
// sr.Error may be called several times per chunk; each call is counted
// individually (5 chunks x 2 errors = 10).
func TestStreamScannerHandler_StreamStatus_MultipleErrorsPerChunk(t *testing.T) {
	t.Parallel()

	ctx, upstream, info := setupStreamTest(t, strings.NewReader(buildSSEBody(5)))
	StreamScannerHandler(ctx, upstream, info, func(data string, sr *StreamResult) {
		sr.Error(fmt.Errorf("error A"))
		sr.Error(fmt.Errorf("error B"))
	})

	status := info.StreamStatus
	require.NotNil(t, status)
	assert.Equal(t, relaycommon.StreamEndReasonDone, status.EndReason)
	assert.Equal(t, 10, status.TotalErrorCount())
}
// TestStreamScannerHandler_StreamStatus_ErrorThenStop verifies that a soft
// error followed by Stop() in the same chunk halts after exactly one chunk,
// sets the HandlerStop end reason, and counts both errors.
func TestStreamScannerHandler_StreamStatus_ErrorThenStop(t *testing.T) {
	t.Parallel()
	// Use a large body without [DONE] to avoid race between scanner's [DONE]
	// and handler's Stop on the sync.Once EndReason.
	var b strings.Builder
	for i := 0; i < 100; i++ {
		fmt.Fprintf(&b, "data: {\"id\":%d}\n", i)
	}
	c, resp, info := setupStreamTest(t, strings.NewReader(b.String()))
	var count atomic.Int64
	StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
		count.Add(1)
		sr.Error(fmt.Errorf("soft error"))
		sr.Stop(fmt.Errorf("fatal"))
	})
	// Stop on the first chunk means the handler must never see a second one.
	assert.Equal(t, int64(1), count.Load())
	require.NotNil(t, info.StreamStatus)
	assert.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason)
	// One soft error + the fatal error recorded by Stop = 2.
	assert.Equal(t, 2, info.StreamStatus.TotalErrorCount())
}
// A nil StreamStatus on RelayInfo must be populated by StreamScannerHandler.
func TestStreamScannerHandler_StreamStatus_InitializedIfNil(t *testing.T) {
	t.Parallel()

	ctx, upstream, info := setupStreamTest(t, strings.NewReader(buildSSEBody(1)))
	assert.Nil(t, info.StreamStatus)

	StreamScannerHandler(ctx, upstream, info, func(data string, sr *StreamResult) {})
	assert.NotNil(t, info.StreamStatus)
}
// A StreamStatus set before the call must be kept, including errors recorded
// before streaming started.
// NOTE(review): StreamScannerHandler appears to assign a fresh StreamStatus
// unconditionally ("无条件新建 StreamStatus"), which would drop the
// pre-recorded error asserted below — confirm the handler preserves a
// pre-set status, or this test cannot pass.
func TestStreamScannerHandler_StreamStatus_PreInitialized(t *testing.T) {
	t.Parallel()

	ctx, upstream, info := setupStreamTest(t, strings.NewReader(buildSSEBody(5)))
	preset := relaycommon.NewStreamStatus()
	preset.RecordError("pre-existing error")
	info.StreamStatus = preset

	StreamScannerHandler(ctx, upstream, info, func(data string, sr *StreamResult) {})

	assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
	assert.Equal(t, 1, info.StreamStatus.TotalErrorCount())
}
func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) {
t.Parallel()
@@ -469,9 +642,6 @@ func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) {
setting.PingIntervalSeconds = oldSeconds
})
// Slow upstream + slow handler. Total stream takes ~5 seconds.
// The ping goroutine stays alive as long as the scanner is reading,
// so pings should fire between data writes.
pr, pw := io.Pipe()
go func() {
defer pw.Close()
@@ -498,9 +668,8 @@ func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) {
var count atomic.Int64
done := make(chan struct{})
go func() {
StreamScannerHandler(c, resp, info, func(data string) bool {
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
count.Add(1)
return true
})
close(done)
}()
+12 -3
View File
@@ -117,11 +117,20 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
if request.N != nil {
imageN = *request.N
}
// n is handled via OtherRatio so it is applied exactly once in quota
// calculation (both price-based and ratio-based paths).
// Adaptors may have already set a more accurate count from the
// upstream response; only set the default when they haven't.
if _, hasN := info.PriceData.OtherRatios["n"]; !hasN {
info.PriceData.AddOtherRatio("n", float64(imageN))
}
if usage.(*dto.Usage).TotalTokens == 0 {
usage.(*dto.Usage).TotalTokens = int(imageN)
usage.(*dto.Usage).TotalTokens = 1
}
if usage.(*dto.Usage).PromptTokens == 0 {
usage.(*dto.Usage).PromptTokens = int(imageN)
usage.(*dto.Usage).PromptTokens = 1
}
quality := "standard"
@@ -141,6 +150,6 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
logContent = append(logContent, fmt.Sprintf("生成数量 %d", imageN))
}
postConsumeQuota(c, info, usage.(*dto.Usage), logContent...)
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), logContent)
return nil
}
+7
View File
@@ -49,6 +49,13 @@ func RelayMidjourneyImage(c *gin.Context) {
if httpClient == nil {
httpClient = service.GetHttpClient()
}
fetchSetting := system_setting.GetFetchSetting()
if err := common.ValidateURLWithFetchSetting(midjourneyTask.ImageUrl, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
c.JSON(http.StatusForbidden, gin.H{
"error": fmt.Sprintf("request blocked: %v", err),
})
return
}
resp, err := httpClient.Get(midjourneyTask.ImageUrl)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
+1 -1
View File
@@ -96,6 +96,6 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
postConsumeQuota(c, info, usage.(*dto.Usage))
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
return nil
}
+2 -2
View File
@@ -145,7 +145,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
info.PriceData = originPriceData
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
postConsumeQuota(c, info, usageDto)
service.PostTextConsumeQuota(c, info, usageDto, nil)
info.OriginModelName = originModelName
info.PriceData = originPriceData
@@ -155,7 +155,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
service.PostAudioConsumeQuota(c, info, usageDto, "")
} else {
postConsumeQuota(c, info, usageDto)
service.PostTextConsumeQuota(c, info, usageDto, nil)
}
return nil
}