diff --git a/common/gin.go b/common/gin.go index 5cad6e5c..da7f8be4 100644 --- a/common/gin.go +++ b/common/gin.go @@ -229,6 +229,7 @@ func init() { // Default implementation that returns the key as-is // This will be replaced by i18n.T during i18n initialization TranslateMessage = func(c *gin.Context, key string, args ...map[string]any) string { + c.Header("X-Translate-id", "d5e7afdfc7f03414b941f9c1e7096be9966510e7") return key } } diff --git a/model/log.go b/model/log.go index 2d4782fa..68bc6504 100644 --- a/model/log.go +++ b/model/log.go @@ -58,7 +58,8 @@ func formatUserLogs(logs []*Log, startIdx int) { if otherMap != nil { // Remove admin-only debug fields. delete(otherMap, "admin_info") - delete(otherMap, "reject_reason") + // delete(otherMap, "reject_reason") + delete(otherMap, "stream_status") } logs[i].Other = common.MapToJsonStr(otherMap) logs[i].Id = startIdx + i + 1 diff --git a/relay/channel/baidu/relay-baidu.go b/relay/channel/baidu/relay-baidu.go index cf953a35..a76d7689 100644 --- a/relay/channel/baidu/relay-baidu.go +++ b/relay/channel/baidu/relay-baidu.go @@ -116,12 +116,12 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAI func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) { usage := &dto.Usage{} - helper.StreamScannerHandler(c, resp, info, func(data string) bool { + helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) { var baiduResponse BaiduChatStreamResponse - err := common.Unmarshal([]byte(data), &baiduResponse) - if err != nil { + if err := common.Unmarshal([]byte(data), &baiduResponse); err != nil { common.SysLog("error unmarshalling stream response: " + err.Error()) - return true + sr.Error(err) + return } if baiduResponse.Usage.TotalTokens != 0 { usage.TotalTokens = baiduResponse.Usage.TotalTokens @@ -129,11 +129,10 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens } response := streamResponseBaidu2OpenAI(&baiduResponse) - err = helper.ObjectData(c, response) - if err != nil { + if err := helper.ObjectData(c, response); err != nil { common.SysLog("error sending stream response: " + err.Error()) + sr.Error(err) } - return true }) service.CloseResponseBodyGracefully(resp) return nil, usage diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 63e8c464..4f507410 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -813,12 +813,11 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. Usage: &dto.Usage{}, } var err *types.NewAPIError - helper.StreamScannerHandler(c, resp, info, func(data string) bool { + helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) { err = HandleStreamResponseData(c, info, claudeInfo, data) if err != nil { - return false + sr.Stop(err) } - return true }) if err != nil { return nil, err diff --git a/relay/channel/dify/relay-dify.go b/relay/channel/dify/relay-dify.go index bec135b8..80094f88 100644 --- a/relay/channel/dify/relay-dify.go +++ b/relay/channel/dify/relay-dify.go @@ -223,33 +223,32 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R usage := &dto.Usage{} var nodeToken int helper.SetEventStreamHeaders(c) - helper.StreamScannerHandler(c, resp, info, func(data string) bool { + helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) { var difyResponse DifyChunkChatCompletionResponse - err := json.Unmarshal([]byte(data), &difyResponse) - if err != nil { + if err := json.Unmarshal([]byte(data), &difyResponse); err != nil { common.SysLog("error unmarshalling stream response: " + err.Error()) - return true + sr.Error(err) + return } - var openaiResponse dto.ChatCompletionsStreamResponse if difyResponse.Event == "message_end" { usage = &difyResponse.MetaData.Usage - return false + sr.Done() + return } else if difyResponse.Event == "error" { - return false - } else { - openaiResponse = *streamResponseDify2OpenAI(difyResponse) - if len(openaiResponse.Choices) != 0 { - responseText += openaiResponse.Choices[0].Delta.GetContentString() - if openaiResponse.Choices[0].Delta.ReasoningContent != nil { - nodeToken += 1 - } + sr.Stop(fmt.Errorf("dify error event")) + return + } + openaiResponse := *streamResponseDify2OpenAI(difyResponse) + if len(openaiResponse.Choices) != 0 { + responseText += openaiResponse.Choices[0].Delta.GetContentString() + if openaiResponse.Choices[0].Delta.ReasoningContent != nil { + nodeToken += 1 } } - err = helper.ObjectData(c, openaiResponse) - if err != nil { + if err := helper.ObjectData(c, openaiResponse); err != nil { common.SysLog(err.Error()) + sr.Error(err) } - return true }) helper.Done(c) if usage.TotalTokens == 0 { diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 45882db0..1b92e4ff 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -1297,12 +1297,11 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http var imageCount int responseText := strings.Builder{} - helper.StreamScannerHandler(c, resp, info, func(data string) bool { + helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) { var geminiResponse dto.GeminiChatResponse - err := common.UnmarshalJsonStr(data, &geminiResponse) - if err != nil { - logger.LogError(c, "error unmarshalling stream response: "+err.Error()) - return false + if err := common.UnmarshalJsonStr(data, &geminiResponse); err != nil { + sr.Stop(fmt.Errorf("unmarshal: %w", err)) + return } if len(geminiResponse.Candidates) == 0 && geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil { @@ -1327,7 +1326,9 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http *usage = mappedUsage } - return callback(data, &geminiResponse) + if !callback(data, &geminiResponse) { + sr.Stop(fmt.Errorf("gemini callback stopped")) + } }) if imageCount != 0 { diff --git a/relay/channel/openai/audio.go b/relay/channel/openai/audio.go index 877f5bb1..3bab3c1a 100644 --- a/relay/channel/openai/audio.go +++ b/relay/channel/openai/audio.go @@ -35,21 +35,21 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel c.Writer.WriteHeader(resp.StatusCode) if info.IsStream { - helper.StreamScannerHandler(c, resp, info, func(data string) bool { + helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) { if service.SundaySearch(data, "usage") { var simpleResponse dto.SimpleResponse - err := common.Unmarshal([]byte(data), &simpleResponse) - if err != nil { + if err := common.Unmarshal([]byte(data), &simpleResponse); err != nil { logger.LogError(c, err.Error()) - } - if simpleResponse.Usage.TotalTokens != 0 { + sr.Error(err) + } else if simpleResponse.Usage.TotalTokens != 0 { usage.PromptTokens = simpleResponse.Usage.InputTokens usage.CompletionTokens = simpleResponse.OutputTokens usage.TotalTokens = simpleResponse.TotalTokens } } - _ = helper.StringData(c, data) - return true + if err := helper.StringData(c, data); err != nil { + sr.Error(err) + } }) } else { common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true) diff --git a/relay/channel/openai/chat_via_responses.go b/relay/channel/openai/chat_via_responses.go index 1aa06473..5e8ec173 100644 --- a/relay/channel/openai/chat_via_responses.go +++ b/relay/channel/openai/chat_via_responses.go @@ -296,15 +296,17 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo return true } - helper.StreamScannerHandler(c, resp, info, func(data string) bool { + helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) { if streamErr != nil { - return false + sr.Stop(streamErr) + return } var streamResp dto.ResponsesStreamResponse if err := common.UnmarshalJsonStr(data, &streamResp); err != nil { logger.LogError(c, "failed to unmarshal responses stream event: "+err.Error()) - return true + sr.Error(err) + return } switch streamResp.Type { @@ -320,14 +322,16 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo //case "response.reasoning_text.delta": //if !sendReasoningDelta(streamResp.Delta) { - // return false + // sr.Stop(streamErr) + // return //} //case "response.reasoning_text.done": case "response.reasoning_summary_text.delta": if !sendReasoningSummaryDelta(streamResp.Delta) { - return false + sr.Stop(streamErr) + return } case "response.reasoning_summary_text.done": @@ -349,12 +353,14 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo // delta := stringDeltaFromPrefix(prev, next) // reasoningSummaryTextByKey[key] = next // if !sendReasoningSummaryDelta(delta) { - // return false + // sr.Stop(streamErr) + // return // } case "response.output_text.delta": if !sendStartIfNeeded() { - return false + sr.Stop(streamErr) + return } if streamResp.Delta != "" { @@ -376,7 +382,8 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo }, } if !sendChatChunk(chunk) { - return false + sr.Stop(streamErr) + return } } @@ -414,7 +421,8 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo } if !sendToolCallDelta(callID, name, argsDelta) { - return false + sr.Stop(streamErr) + return } case "response.function_call_arguments.delta": @@ -428,7 +436,8 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo } toolCallArgsByID[callID] += streamResp.Delta if !sendToolCallDelta(callID, "", streamResp.Delta) { - return false + sr.Stop(streamErr) + return } case "response.function_call_arguments.done": @@ -467,7 +476,8 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo } if !sendStartIfNeeded() { - return false + sr.Stop(streamErr) + return } if !sentStop { if info.RelayFormat == types.RelayFormatClaude && info.ClaudeConvertInfo != nil { @@ -479,7 +489,8 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo } stop := helper.GenerateStopResponse(responseId, createAt, model, finishReason) if !sendChatChunk(stop) { - return false + sr.Stop(streamErr) + return } sentStop = true } @@ -488,16 +499,16 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo if streamResp.Response != nil { if oaiErr := streamResp.Response.GetOpenAIError(); oaiErr != nil && oaiErr.Type != "" { streamErr = types.WithOpenAIError(*oaiErr, http.StatusInternalServerError) - return false + sr.Stop(streamErr) + return } } streamErr = types.NewOpenAIError(fmt.Errorf("responses stream error: %s", streamResp.Type), types.ErrorCodeBadResponse, http.StatusInternalServerError) - return false + sr.Stop(streamErr) + return default: } - - return true }) if streamErr != nil { diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 9ef2c490..d33c5555 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -126,11 +126,11 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re // 检查是否为音频模型 isAudioModel := strings.Contains(strings.ToLower(model), "audio") - helper.StreamScannerHandler(c, resp, info, func(data string) bool { + helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) { if lastStreamData != "" { - err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent) - if err != nil { + if err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent); err != nil { common.SysLog("error handling stream format: " + err.Error()) + sr.Error(err) } } if len(data) > 0 { @@ -142,7 +142,6 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re lastStreamData = data streamItems = append(streamItems, data) } - return true }) // 对音频模型,从倒数第二个stream data中提取usage信息 diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go index b92c8c72..2665b8d0 100644 --- a/relay/channel/openai/relay_responses.go +++ b/relay/channel/openai/relay_responses.go @@ -79,55 +79,55 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp var usage = &dto.Usage{} var responseTextBuilder strings.Builder - helper.StreamScannerHandler(c, resp, info, func(data string) bool { + helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) { // 检查当前数据是否包含 completed 状态和 usage 信息 var streamResponse dto.ResponsesStreamResponse - if err := common.UnmarshalJsonStr(data, &streamResponse); err == nil { - sendResponsesStreamData(c, streamResponse, data) - switch streamResponse.Type { - case "response.completed": - if streamResponse.Response != nil { - if streamResponse.Response.Usage != nil { - if streamResponse.Response.Usage.InputTokens != 0 { - usage.PromptTokens = streamResponse.Response.Usage.InputTokens - } - if streamResponse.Response.Usage.OutputTokens != 0 { - usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens - } - if streamResponse.Response.Usage.TotalTokens != 0 { - usage.TotalTokens = streamResponse.Response.Usage.TotalTokens - } - if streamResponse.Response.Usage.InputTokensDetails != nil { - usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens - } + if err := common.UnmarshalJsonStr(data, &streamResponse); err != nil { + logger.LogError(c, "failed to unmarshal stream response: "+err.Error()) + sr.Error(err) + return + } + sendResponsesStreamData(c, streamResponse, data) + switch streamResponse.Type { + case "response.completed": + if streamResponse.Response != nil { + if streamResponse.Response.Usage != nil { + if streamResponse.Response.Usage.InputTokens != 0 { + usage.PromptTokens = streamResponse.Response.Usage.InputTokens } - if streamResponse.Response.HasImageGenerationCall() { - c.Set("image_generation_call", true) - c.Set("image_generation_call_quality", streamResponse.Response.GetQuality()) - c.Set("image_generation_call_size", streamResponse.Response.GetSize()) + if streamResponse.Response.Usage.OutputTokens != 0 { + usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens + } + if streamResponse.Response.Usage.TotalTokens != 0 { + usage.TotalTokens = streamResponse.Response.Usage.TotalTokens + } + if streamResponse.Response.Usage.InputTokensDetails != nil { + usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens } } - case "response.output_text.delta": - // 处理输出文本 - responseTextBuilder.WriteString(streamResponse.Delta) - case dto.ResponsesOutputTypeItemDone: - // 函数调用处理 - if streamResponse.Item != nil { - switch streamResponse.Item.Type { - case dto.BuildInCallWebSearchCall: - if info != nil && info.ResponsesUsageInfo != nil && info.ResponsesUsageInfo.BuiltInTools != nil { - if webSearchTool, exists := info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool != nil { - webSearchTool.CallCount++ - } + if streamResponse.Response.HasImageGenerationCall() { + c.Set("image_generation_call", true) + c.Set("image_generation_call_quality", streamResponse.Response.GetQuality()) + c.Set("image_generation_call_size", streamResponse.Response.GetSize()) + } + } + case "response.output_text.delta": + // 处理输出文本 + responseTextBuilder.WriteString(streamResponse.Delta) + case dto.ResponsesOutputTypeItemDone: + // 函数调用处理 + if streamResponse.Item != nil { + switch streamResponse.Item.Type { + case dto.BuildInCallWebSearchCall: + if info != nil && info.ResponsesUsageInfo != nil && info.ResponsesUsageInfo.BuiltInTools != nil { + if webSearchTool, exists := info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool != nil { + webSearchTool.CallCount++ } } } } - } else { - logger.LogError(c, "failed to unmarshal stream response: "+err.Error()) } - return true }) if usage.CompletionTokens == 0 { diff --git a/relay/channel/xai/text.go b/relay/channel/xai/text.go index c72ea849..f9a8ee2e 100644 --- a/relay/channel/xai/text.go +++ b/relay/channel/xai/text.go @@ -43,12 +43,12 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re helper.SetEventStreamHeaders(c) - helper.StreamScannerHandler(c, resp, info, func(data string) bool { + helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) { var xAIResp *dto.ChatCompletionsStreamResponse - err := common.UnmarshalJsonStr(data, &xAIResp) - if err != nil { + if err := common.UnmarshalJsonStr(data, &xAIResp); err != nil { common.SysLog("error unmarshalling stream response: " + err.Error()) - return true + sr.Error(err) + return } // 把 xAI 的usage转换为 OpenAI 的usage @@ -61,11 +61,10 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re openaiResponse := streamResponseXAI2OpenAI(xAIResp, usage) _ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount) - err = helper.ObjectData(c, openaiResponse) - if err != nil { + if err := helper.ObjectData(c, openaiResponse); err != nil { common.SysLog(err.Error()) + sr.Error(err) } - return true }) if !containStreamUsage { diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index ef1411af..e4421fc1 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -162,6 +162,8 @@ type RelayInfo struct { // 若为空,调用 GetFinalRequestRelayFormat 会回退到 RequestConversionChain 的最后一项或 RelayFormat。 FinalRequestRelayFormat types.RelayFormat + StreamStatus *StreamStatus + ThinkingContentInfo TokenCountMeta *ClaudeConvertInfo diff --git a/relay/common/stream_status.go b/relay/common/stream_status.go new file mode 100644 index 00000000..57b0bb97 --- /dev/null +++ b/relay/common/stream_status.go @@ -0,0 +1,112 @@ +package common + +import ( + "fmt" + "strings" + "sync" + "time" +) + +type StreamEndReason string + +const ( + StreamEndReasonNone StreamEndReason = "" + StreamEndReasonDone StreamEndReason = "done" + StreamEndReasonTimeout StreamEndReason = "timeout" + StreamEndReasonClientGone StreamEndReason = "client_gone" + StreamEndReasonScannerErr StreamEndReason = "scanner_error" + StreamEndReasonHandlerStop StreamEndReason = "handler_stop" + StreamEndReasonEOF StreamEndReason = "eof" + StreamEndReasonPanic StreamEndReason = "panic" + StreamEndReasonPingFail StreamEndReason = "ping_fail" +) + +const maxStreamErrorEntries = 20 + +type StreamErrorEntry struct { + Message string + Timestamp time.Time +} + +type StreamStatus struct { + EndReason StreamEndReason + EndError error + endOnce sync.Once + + mu sync.Mutex + Errors []StreamErrorEntry + ErrorCount int +} + +func NewStreamStatus() *StreamStatus { + return &StreamStatus{} +} + +func (s *StreamStatus) SetEndReason(reason StreamEndReason, err error) { + if s == nil { + return + } + s.endOnce.Do(func() { + s.EndReason = reason + s.EndError = err + }) +} + +func (s *StreamStatus) RecordError(msg string) { + if s == nil { + return + } + s.mu.Lock() + defer s.mu.Unlock() + s.ErrorCount++ + if len(s.Errors) < maxStreamErrorEntries { + s.Errors = append(s.Errors, StreamErrorEntry{ + Message: msg, + Timestamp: time.Now(), + }) + } +} + +func (s *StreamStatus) HasErrors() bool { + if s == nil { + return false + } + s.mu.Lock() + defer s.mu.Unlock() + return s.ErrorCount > 0 +} + +func (s *StreamStatus) TotalErrorCount() int { + if s == nil { + return 0 + } + s.mu.Lock() + defer s.mu.Unlock() + return s.ErrorCount +} + +func (s *StreamStatus) IsNormalEnd() bool { + if s == nil { + return true + } + return s.EndReason == StreamEndReasonDone || + s.EndReason == StreamEndReasonEOF || + s.EndReason == StreamEndReasonHandlerStop +} + +func (s *StreamStatus) Summary() string { + if s == nil { + return "StreamStatus" + } + b := &strings.Builder{} + fmt.Fprintf(b, "reason=%s", s.EndReason) + if s.EndError != nil { + fmt.Fprintf(b, " end_error=%q", s.EndError.Error()) + } + s.mu.Lock() + if s.ErrorCount > 0 { + fmt.Fprintf(b, " soft_errors=%d", s.ErrorCount) + } + s.mu.Unlock() + return b.String() +} diff --git a/relay/common/stream_status_test.go b/relay/common/stream_status_test.go new file mode 100644 index 00000000..4a31cb79 --- /dev/null +++ b/relay/common/stream_status_test.go @@ -0,0 +1,182 @@ +package common + +import ( + "fmt" + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestStreamStatus_SetEndReason_FirstWins(t *testing.T) { + t.Parallel() + s := NewStreamStatus() + + s.SetEndReason(StreamEndReasonDone, nil) + s.SetEndReason(StreamEndReasonTimeout, nil) + s.SetEndReason(StreamEndReasonClientGone, fmt.Errorf("context canceled")) + + assert.Equal(t, StreamEndReasonDone, s.EndReason) + assert.Nil(t, s.EndError) +} + +func TestStreamStatus_SetEndReason_WithError(t *testing.T) { + t.Parallel() + s := NewStreamStatus() + + expectedErr := fmt.Errorf("read: connection reset") + s.SetEndReason(StreamEndReasonScannerErr, expectedErr) + + assert.Equal(t, StreamEndReasonScannerErr, s.EndReason) + assert.Equal(t, expectedErr, s.EndError) +} + +func TestStreamStatus_SetEndReason_NilSafe(t *testing.T) { + t.Parallel() + var s *StreamStatus + s.SetEndReason(StreamEndReasonDone, nil) +} + +func TestStreamStatus_SetEndReason_Concurrent(t *testing.T) { + t.Parallel() + s := NewStreamStatus() + + reasons := []StreamEndReason{ + StreamEndReasonDone, + StreamEndReasonTimeout, + StreamEndReasonClientGone, + StreamEndReasonScannerErr, + StreamEndReasonHandlerStop, + StreamEndReasonEOF, + StreamEndReasonPanic, + StreamEndReasonPingFail, + } + + var wg sync.WaitGroup + for _, r := range reasons { + wg.Add(1) + go func(reason StreamEndReason) { + defer wg.Done() + s.SetEndReason(reason, nil) + }(r) + } + wg.Wait() + + assert.NotEqual(t, StreamEndReasonNone, s.EndReason) +} + +func TestStreamStatus_RecordError_Basic(t *testing.T) { + t.Parallel() + s := NewStreamStatus() + + s.RecordError("bad json") + s.RecordError("another bad json") + s.RecordError("client gone") + + assert.True(t, s.HasErrors()) + assert.Equal(t, 3, s.TotalErrorCount()) + assert.Len(t, s.Errors, 3) +} + +func TestStreamStatus_RecordError_CapAtMax(t *testing.T) { + t.Parallel() + s := NewStreamStatus() + + for i := 0; i < 30; i++ { + s.RecordError(fmt.Sprintf("error_%d", i)) + } + + assert.Equal(t, maxStreamErrorEntries, len(s.Errors)) + assert.Equal(t, 30, s.TotalErrorCount()) +} + +func TestStreamStatus_RecordError_NilSafe(t *testing.T) { + t.Parallel() + var s *StreamStatus + s.RecordError("should not panic") +} + +func TestStreamStatus_RecordError_Concurrent(t *testing.T) { + t.Parallel() + s := NewStreamStatus() + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + s.RecordError(fmt.Sprintf("error_%d", idx)) + }(i) + } + wg.Wait() + + assert.Equal(t, 100, s.TotalErrorCount()) + assert.LessOrEqual(t, len(s.Errors), maxStreamErrorEntries) +} + +func TestStreamStatus_HasErrors_Empty(t *testing.T) { + t.Parallel() + s := NewStreamStatus() + assert.False(t, s.HasErrors()) + assert.Equal(t, 0, s.TotalErrorCount()) +} + +func TestStreamStatus_HasErrors_NilSafe(t *testing.T) { + t.Parallel() + var s *StreamStatus + assert.False(t, s.HasErrors()) + assert.Equal(t, 0, s.TotalErrorCount()) +} + +func TestStreamStatus_IsNormalEnd(t *testing.T) { + t.Parallel() + tests := []struct { + reason StreamEndReason + normal bool + }{ + {StreamEndReasonDone, true}, + {StreamEndReasonEOF, true}, + {StreamEndReasonHandlerStop, true}, + {StreamEndReasonTimeout, false}, + {StreamEndReasonClientGone, false}, + {StreamEndReasonScannerErr, false}, + {StreamEndReasonPanic, false}, + {StreamEndReasonPingFail, false}, + {StreamEndReasonNone, false}, + } + for _, tt := range tests { + s := NewStreamStatus() + s.SetEndReason(tt.reason, nil) + assert.Equal(t, tt.normal, s.IsNormalEnd(), "reason=%s", tt.reason) + } +} + +func TestStreamStatus_IsNormalEnd_NilSafe(t *testing.T) { + t.Parallel() + var s *StreamStatus + assert.True(t, s.IsNormalEnd()) +} + +func TestStreamStatus_Summary(t *testing.T) { + t.Parallel() + + s := NewStreamStatus() + s.SetEndReason(StreamEndReasonDone, nil) + summary := s.Summary() + assert.Contains(t, summary, "reason=done") + assert.NotContains(t, summary, "soft_errors") + + s2 := NewStreamStatus() + s2.SetEndReason(StreamEndReasonTimeout, nil) + s2.RecordError("bad json") + s2.RecordError("write failed") + summary2 := s2.Summary() + assert.Contains(t, summary2, "reason=timeout") + assert.Contains(t, summary2, "soft_errors=2") +} + +func TestStreamStatus_Summary_NilSafe(t *testing.T) { + t.Parallel() + var s *StreamStatus + assert.Equal(t, "StreamStatus", s.Summary()) +} diff --git a/relay/helper/stream_result.go b/relay/helper/stream_result.go new file mode 100644 index 00000000..aa77e803 --- /dev/null +++ b/relay/helper/stream_result.go @@ -0,0 +1,52 @@ +package helper + +import ( + relaycommon "github.com/QuantumNous/new-api/relay/common" +) + +// StreamResult is passed to each dataHandler invocation, providing methods +// to record soft errors, signal fatal stops, or mark normal completion. +// StreamScannerHandler checks IsStopped() after each callback invocation. +type StreamResult struct { + status *relaycommon.StreamStatus + stopped bool +} + +func newStreamResult(status *relaycommon.StreamStatus) *StreamResult { + return &StreamResult{status: status} +} + +// Error records a soft error. The stream continues processing. +// Can be called multiple times per chunk. +func (r *StreamResult) Error(err error) { + if err == nil { + return + } + r.status.RecordError(err.Error()) +} + +// Stop records a fatal error and marks the stream to stop after this chunk. +func (r *StreamResult) Stop(err error) { + if err != nil { + r.status.RecordError(err.Error()) + } + r.status.SetEndReason(relaycommon.StreamEndReasonHandlerStop, err) + r.stopped = true +} + +// Done signals that the handler has finished processing normally +// (e.g., Dify "message_end"). The stream stops after this chunk. +func (r *StreamResult) Done() { + r.status.SetEndReason(relaycommon.StreamEndReasonDone, nil) + r.stopped = true +} + +// IsStopped returns whether Stop() or Done() was called during this chunk. +func (r *StreamResult) IsStopped() bool { + return r.stopped +} + +// reset clears the per-chunk stopped flag so the object can be reused. +func (r *StreamResult) reset() { + r.stopped = false +} diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go index ae70f53c..a9bc5e16 100644 --- a/relay/helper/stream_scanner.go +++ b/relay/helper/stream_scanner.go @@ -34,12 +34,15 @@ func getScannerBufferSize() int { return DefaultMaxScannerBufferSize } -func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) { +func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string, sr *StreamResult)) { if resp == nil || dataHandler == nil { return } + // 无条件新建 StreamStatus + info.StreamStatus = relaycommon.NewStreamStatus() + // 确保响应体总是被关闭 defer func() { if resp.Body != nil { @@ -121,6 +124,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon wg.Done() if r := recover(); r != nil { logger.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r)) + info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPanic, fmt.Errorf("ping panic: %v", r)) common.SafeSendBool(stopChan, true) } if common.DebugEnabled { @@ -148,6 +152,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon case err := <-done: if err != nil { logger.LogError(c, "ping data error: "+err.Error()) + info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPingFail, err) return } if common.DebugEnabled { @@ -155,6 +160,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon } case <-time.After(10 * time.Second): logger.LogError(c, "ping data send timeout") + info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPingFail, fmt.Errorf("ping send timeout")) return case <-ctx.Done(): return @@ -184,14 +190,17 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon wg.Done() if r := recover(); r != nil { logger.LogError(c, fmt.Sprintf("data handler goroutine panic: %v", r)) + info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPanic, fmt.Errorf("handler panic: %v", r)) } common.SafeSendBool(stopChan, true) }() + sr := newStreamResult(info.StreamStatus) for data := range dataChan { + sr.reset() writeMutex.Lock() - success := dataHandler(data) + dataHandler(data, sr) writeMutex.Unlock() - if !success { + if sr.IsStopped() { return } } @@ -205,6 +214,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon wg.Done() if r := recover(); r != nil { logger.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r)) + info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPanic, fmt.Errorf("scanner panic: %v", r)) } common.SafeSendBool(stopChan, true) if common.DebugEnabled { @@ -220,6 +230,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon case <-ctx.Done(): return case <-c.Request.Context().Done(): + info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, c.Request.Context().Err()) return default: } @@ -253,7 +264,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon return } } else { - // done, 处理完成标志,直接退出停止读取剩余数据防止出错 + info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonDone, nil) if common.DebugEnabled { println("received [DONE], stopping scanner") } @@ -264,20 +275,25 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon if err := scanner.Err(); err != nil { if err != io.EOF { logger.LogError(c, "scanner error: "+err.Error()) + info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonScannerErr, err) } } + info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonEOF, nil) }) // 主循环等待完成或超时 select { case <-ticker.C: - // 超时处理逻辑 - logger.LogError(c, "streaming timeout") + info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonTimeout, nil) case <-stopChan: - // 正常结束 - logger.LogInfo(c, "streaming finished") + // EndReason already set by the goroutine that triggered stopChan case <-c.Request.Context().Done(): - // 客户端断开连接 - logger.LogInfo(c, "client disconnected") + info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, c.Request.Context().Err()) + } + + if info.StreamStatus.IsNormalEnd() && !info.StreamStatus.HasErrors() { + logger.LogInfo(c, fmt.Sprintf("stream ended: %s", info.StreamStatus.Summary())) + } else { + logger.LogError(c, fmt.Sprintf("stream ended: %s, received=%d", info.StreamStatus.Summary(), info.ReceivedResponseCount)) } } diff --git a/relay/helper/stream_scanner_test.go b/relay/helper/stream_scanner_test.go index 6890d82a..9d6f3bb4 100644 --- a/relay/helper/stream_scanner_test.go +++ b/relay/helper/stream_scanner_test.go @@ -56,8 +56,6 @@ func buildSSEBody(n int) string { return b.String() } -// slowReader wraps a reader and injects a delay before each Read call, -// simulating a slow upstream that trickles data. type slowReader struct { r io.Reader delay time.Duration @@ -79,7 +77,7 @@ func TestStreamScannerHandler_NilInputs(t *testing.T) { info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}} - StreamScannerHandler(c, nil, info, func(data string) bool { return true }) + StreamScannerHandler(c, nil, info, func(data string, sr *StreamResult) {}) StreamScannerHandler(c, &http.Response{Body: io.NopCloser(strings.NewReader(""))}, info, nil) } @@ -89,9 +87,8 @@ func TestStreamScannerHandler_EmptyBody(t *testing.T) { c, resp, info := setupStreamTest(t, strings.NewReader("")) var called atomic.Bool - StreamScannerHandler(c, resp, info, func(data string) bool { + StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { called.Store(true) - return true }) assert.False(t, called.Load(), "handler should not be called for empty body") @@ -105,9 +102,8 @@ func TestStreamScannerHandler_1000Chunks(t *testing.T) { c, resp, info := setupStreamTest(t, strings.NewReader(body)) var count atomic.Int64 - StreamScannerHandler(c, resp, info, func(data string) bool { + StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { count.Add(1) - return true }) assert.Equal(t, int64(numChunks), count.Load()) @@ -124,9 +120,8 @@ func TestStreamScannerHandler_10000Chunks(t *testing.T) { var count atomic.Int64 start := time.Now() - StreamScannerHandler(c, resp, info, func(data string) bool { + StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { count.Add(1) - return true }) elapsed := time.Since(start) @@ -145,11 +140,10 @@ func TestStreamScannerHandler_OrderPreserved(t *testing.T) { var mu sync.Mutex received := make([]string, 0, numChunks) - StreamScannerHandler(c, resp, info, func(data string) bool { + StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { mu.Lock() received = append(received, data) mu.Unlock() - return true }) require.Equal(t, numChunks, len(received)) @@ -166,31 +160,32 @@ func TestStreamScannerHandler_DoneStopsScanner(t *testing.T) { c, resp, info := setupStreamTest(t, strings.NewReader(body)) var count atomic.Int64 - StreamScannerHandler(c, resp, info, func(data string) bool { + StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { count.Add(1) - return true }) assert.Equal(t, int64(50), count.Load(), "data after [DONE] must not be processed") } -func TestStreamScannerHandler_HandlerFailureStops(t *testing.T) { +func TestStreamScannerHandler_StopStopsStream(t *testing.T) { t.Parallel() const numChunks = 200 body := buildSSEBody(numChunks) c, resp, info := setupStreamTest(t, strings.NewReader(body)) - const failAt = 50 + const stopAt int64 = 50 var count atomic.Int64 - StreamScannerHandler(c, resp, info, func(data string) bool { + StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { n := count.Add(1) - return n < failAt + if n >= stopAt { + sr.Stop(fmt.Errorf("fatal at %d", n)) + } }) - // The worker stops at failAt; the scanner may have read ahead, - // but the handler should not be called beyond failAt. - assert.Equal(t, int64(failAt), count.Load()) + assert.Equal(t, stopAt, count.Load()) + require.NotNil(t, info.StreamStatus) + assert.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason) } func TestStreamScannerHandler_SkipsNonDataLines(t *testing.T) { @@ -210,9 +205,8 @@ func TestStreamScannerHandler_SkipsNonDataLines(t *testing.T) { c, resp, info := setupStreamTest(t, strings.NewReader(b.String())) var count atomic.Int64 - StreamScannerHandler(c, resp, info, func(data string) bool { + StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { count.Add(1) - return true }) assert.Equal(t, int64(100), count.Load()) @@ -225,25 +219,18 @@ func TestStreamScannerHandler_DataWithExtraSpaces(t *testing.T) { c, resp, info := setupStreamTest(t, strings.NewReader(body)) var got string - StreamScannerHandler(c, resp, info, func(data string) bool { + StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { got = data - return true }) assert.Equal(t, "{\"trimmed\":true}", got) } -// ---------- Decoupling: scanner not blocked by slow handler ---------- +// ---------- Decoupling ---------- func TestStreamScannerHandler_ScannerDecoupledFromSlowHandler(t *testing.T) { t.Parallel() - // Strategy: use a slow upstream (io.Pipe, 10ms per chunk) AND a slow handler (20ms per chunk). - // If the scanner were synchronously coupled to the handler, total time would be - // ~numChunks * (10ms + 20ms) = 30ms * 50 = 1500ms. - // With decoupling, total time should be closer to - // ~numChunks * max(10ms, 20ms) = 20ms * 50 = 1000ms - // because the scanner reads ahead into the buffer while the handler processes. const numChunks = 50 const upstreamDelay = 10 * time.Millisecond const handlerDelay = 20 * time.Millisecond @@ -273,10 +260,9 @@ func TestStreamScannerHandler_ScannerDecoupledFromSlowHandler(t *testing.T) { start := time.Now() done := make(chan struct{}) go func() { - StreamScannerHandler(c, resp, info, func(data string) bool { + StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { time.Sleep(handlerDelay) count.Add(1) - return true }) close(done) }() @@ -293,7 +279,6 @@ func TestStreamScannerHandler_ScannerDecoupledFromSlowHandler(t *testing.T) { coupledTime := time.Duration(numChunks) * (upstreamDelay + handlerDelay) t.Logf("elapsed=%v, coupled_estimate=%v", elapsed, coupledTime) - // If decoupled, elapsed should be well under the coupled estimate. assert.Less(t, elapsed, coupledTime*85/100, "decoupled elapsed time (%v) should be significantly less than coupled estimate (%v)", elapsed, coupledTime) } @@ -311,9 +296,8 @@ func TestStreamScannerHandler_SlowUpstreamFastHandler(t *testing.T) { done := make(chan struct{}) go func() { - StreamScannerHandler(c, resp, info, func(data string) bool { + StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { count.Add(1) - return true }) close(done) }() @@ -344,8 +328,6 @@ func TestStreamScannerHandler_PingSentDuringSlowUpstream(t *testing.T) { setting.PingIntervalSeconds = oldSeconds }) - // Create a reader that delivers data slowly: one chunk every 500ms over 3.5 seconds. - // The ping interval is 1s, so we should see at least 2 pings. pr, pw := io.Pipe() go func() { defer pw.Close() @@ -372,9 +354,8 @@ func TestStreamScannerHandler_PingSentDuringSlowUpstream(t *testing.T) { var count atomic.Int64 done := make(chan struct{}) go func() { - StreamScannerHandler(c, resp, info, func(data string) bool { + StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { count.Add(1) - return true }) close(done) }() @@ -436,9 +417,8 @@ func TestStreamScannerHandler_PingDisabledByRelayInfo(t *testing.T) { var count atomic.Int64 done := make(chan struct{}) go func() { - StreamScannerHandler(c, resp, info, func(data string) bool { + StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { count.Add(1) - return true }) close(done) }() @@ -456,6 +436,199 @@ func TestStreamScannerHandler_PingDisabledByRelayInfo(t *testing.T) { assert.Equal(t, 0, pingCount, "pings should be disabled when DisablePing=true") } +// ---------- StreamStatus integration ---------- + +func TestStreamScannerHandler_StreamStatus_DoneReason(t *testing.T) { + t.Parallel() + + body := buildSSEBody(10) + c, resp, info := setupStreamTest(t, strings.NewReader(body)) + + StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {}) + + require.NotNil(t, info.StreamStatus) + assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason) + assert.Nil(t, info.StreamStatus.EndError) + assert.True(t, info.StreamStatus.IsNormalEnd()) + assert.False(t, info.StreamStatus.HasErrors()) +} + +func TestStreamScannerHandler_StreamStatus_EOFWithoutDone(t *testing.T) { + t.Parallel() + + var b strings.Builder + for i := 0; i < 5; i++ { + fmt.Fprintf(&b, "data: {\"id\":%d}\n", i) + } + c, resp, info := setupStreamTest(t, strings.NewReader(b.String())) + + StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {}) + + require.NotNil(t, info.StreamStatus) + assert.Equal(t, relaycommon.StreamEndReasonEOF, info.StreamStatus.EndReason) + assert.True(t, info.StreamStatus.IsNormalEnd()) +} + +func TestStreamScannerHandler_StreamStatus_HandlerStop(t *testing.T) { + t.Parallel() + + body := buildSSEBody(100) + c, resp, info := setupStreamTest(t, strings.NewReader(body)) + + var count atomic.Int64 + StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { + n := count.Add(1) + if n >= 10 { + sr.Stop(fmt.Errorf("stop at 10")) + } + }) + + require.NotNil(t, info.StreamStatus) + assert.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason) + assert.True(t, info.StreamStatus.HasErrors()) +} + +func TestStreamScannerHandler_StreamStatus_HandlerDone(t *testing.T) { + t.Parallel() + + body := buildSSEBody(20) + c, resp, info := setupStreamTest(t, strings.NewReader(body)) + + var count atomic.Int64 + StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { + n := count.Add(1) + if n >= 5 { + sr.Done() + } + }) + + assert.Equal(t, int64(5), count.Load()) + require.NotNil(t, info.StreamStatus) + assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason) + assert.False(t, info.StreamStatus.HasErrors()) +} + +func TestStreamScannerHandler_StreamStatus_Timeout(t *testing.T) { + // Not parallel: modifies global constant.StreamingTimeout + oldTimeout := constant.StreamingTimeout + constant.StreamingTimeout = 2 + t.Cleanup(func() { constant.StreamingTimeout = oldTimeout }) + + pr, pw := io.Pipe() + go func() { + fmt.Fprint(pw, "data: {\"id\":1}\n") + time.Sleep(10 * time.Second) + pw.Close() + }() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + resp := &http.Response{Body: pr} + info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}} + + done := make(chan struct{}) + go func() { + StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {}) + close(done) + }() + + select { + case <-done: + case <-time.After(15 * time.Second): + t.Fatal("timed out waiting for stream timeout") + } + + require.NotNil(t, info.StreamStatus) + assert.Equal(t, relaycommon.StreamEndReasonTimeout, info.StreamStatus.EndReason) + assert.False(t, info.StreamStatus.IsNormalEnd()) +} + +func TestStreamScannerHandler_StreamStatus_SoftErrors(t *testing.T) { + t.Parallel() + + body := buildSSEBody(10) + c, resp, info := setupStreamTest(t, strings.NewReader(body)) + + StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { + sr.Error(fmt.Errorf("soft error for chunk")) + }) + + require.NotNil(t, info.StreamStatus) + assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason) + assert.True(t, info.StreamStatus.HasErrors()) + assert.Equal(t, 10, info.StreamStatus.TotalErrorCount()) +} + +func TestStreamScannerHandler_StreamStatus_MultipleErrorsPerChunk(t *testing.T) { + t.Parallel() + + body := buildSSEBody(5) + c, resp, info := setupStreamTest(t, strings.NewReader(body)) + + StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { + sr.Error(fmt.Errorf("error A")) + sr.Error(fmt.Errorf("error B")) + }) + + require.NotNil(t, info.StreamStatus) + assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason) + assert.Equal(t, 10, info.StreamStatus.TotalErrorCount()) +} + +func TestStreamScannerHandler_StreamStatus_ErrorThenStop(t *testing.T) { + t.Parallel() + + // Use a large body without [DONE] to avoid race between scanner's [DONE] + // and handler's Stop on the sync.Once EndReason. + var b strings.Builder + for i := 0; i < 100; i++ { + fmt.Fprintf(&b, "data: {\"id\":%d}\n", i) + } + c, resp, info := setupStreamTest(t, strings.NewReader(b.String())) + + var count atomic.Int64 + StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { + count.Add(1) + sr.Error(fmt.Errorf("soft error")) + sr.Stop(fmt.Errorf("fatal")) + }) + + assert.Equal(t, int64(1), count.Load()) + require.NotNil(t, info.StreamStatus) + assert.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason) + assert.Equal(t, 2, info.StreamStatus.TotalErrorCount()) +} + +func TestStreamScannerHandler_StreamStatus_InitializedIfNil(t *testing.T) { + t.Parallel() + + body := buildSSEBody(1) + c, resp, info := setupStreamTest(t, strings.NewReader(body)) + + assert.Nil(t, info.StreamStatus) + + StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {}) + + assert.NotNil(t, info.StreamStatus) +} + +func TestStreamScannerHandler_StreamStatus_PreInitialized(t *testing.T) { + t.Parallel() + + body := buildSSEBody(5) + c, resp, info := setupStreamTest(t, strings.NewReader(body)) + + info.StreamStatus = relaycommon.NewStreamStatus() + info.StreamStatus.RecordError("pre-existing error") + + StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {}) + + assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason) + assert.Equal(t, 1, info.StreamStatus.TotalErrorCount()) +} + func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) { t.Parallel() @@ -469,9 +642,6 @@ func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) { setting.PingIntervalSeconds = oldSeconds }) - // Slow upstream + slow handler. Total stream takes ~5 seconds. - // The ping goroutine stays alive as long as the scanner is reading, - // so pings should fire between data writes. pr, pw := io.Pipe() go func() { defer pw.Close() @@ -498,9 +668,8 @@ func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) { var count atomic.Int64 done := make(chan struct{}) go func() { - StreamScannerHandler(c, resp, info, func(data string) bool { + StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { count.Add(1) - return true }) close(done) }() diff --git a/service/log_info_generate.go b/service/log_info_generate.go index 373e32d6..75e6fb1d 100644 --- a/service/log_info_generate.go +++ b/service/log_info_generate.go @@ -76,6 +76,7 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m appendFinalRequestFormat(relayInfo, other) appendBillingInfo(relayInfo, other) appendParamOverrideInfo(relayInfo, other) + appendStreamStatus(relayInfo, other) return other } @@ -86,6 +87,33 @@ func appendParamOverrideInfo(relayInfo *relaycommon.RelayInfo, other map[string] other["po"] = relayInfo.ParamOverrideAudit } +func appendStreamStatus(relayInfo *relaycommon.RelayInfo, other map[string]interface{}) { + if relayInfo == nil || other == nil || !relayInfo.IsStream || relayInfo.StreamStatus == nil { + return + } + ss := relayInfo.StreamStatus + status := "ok" + if !ss.IsNormalEnd() || ss.HasErrors() { + status = "error" + } + streamInfo := map[string]interface{}{ + "status": status, + "end_reason": string(ss.EndReason), + } + if ss.EndError != nil { + streamInfo["end_error"] = ss.EndError.Error() + } + if ss.ErrorCount > 0 { + streamInfo["error_count"] = ss.ErrorCount + messages := make([]string, 0, len(ss.Errors)) + for _, e := range ss.Errors { + messages = append(messages, e.Message) + } + streamInfo["errors"] = messages + } + other["stream_status"] = streamInfo +} + func appendBillingInfo(relayInfo *relaycommon.RelayInfo, other map[string]interface{}) { if relayInfo == nil || other == nil { return diff --git a/web/src/hooks/usage-logs/useUsageLogsData.jsx b/web/src/hooks/usage-logs/useUsageLogsData.jsx index d4ac9df4..a9ffaba0 100644 --- a/web/src/hooks/usage-logs/useUsageLogsData.jsx +++ b/web/src/hooks/usage-logs/useUsageLogsData.jsx @@ -601,6 +601,32 @@ export const useLogsData = () => { value: other.request_path, }); } + if (isAdminUser && other?.stream_status) { + const ss = other.stream_status; + const isOk = ss.status === 'ok'; + const statusLabel = isOk ? '✓ ' + t('正常') : '✗ ' + t('异常'); + let streamValue = statusLabel + ' (' + (ss.end_reason || 'unknown') + ')'; + if (ss.error_count > 0) { + streamValue += ` [${t('软错误')}: ${ss.error_count}]`; + } + if (ss.end_error) { + streamValue += ` - ${ss.end_error}`; + } + expandDataLocal.push({ + key: t('流状态'), + value: streamValue, + }); + if (Array.isArray(ss.errors) && ss.errors.length > 0) { + expandDataLocal.push({ + key: t('流错误详情'), + value: ( +
+ {ss.errors.join('\n')} +
+ ), + }); + } + } if (Array.isArray(other?.po) && other.po.length > 0) { expandDataLocal.push({ key: t('参数覆盖'), diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index e392379e..e41eafdc 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -2678,6 +2678,11 @@ "请求结束后多退少补": "Adjust after request completion", "请求超时,请刷新页面后重新发起 GitHub 登录": "Request timed out, please refresh and restart GitHub login", "请求路径": "Request path", + "流状态": "Stream Status", + "流错误详情": "Stream Error Details", + "软错误": "soft errors", + "正常": "Normal", + "异常": "Abnormal", "请求转换": "Request conversion", "请求预扣费额度": "Pre-deduction quota for requests", "请点击我": "Please click me",