5238f279db
- Add StreamStatus type (relay/common) to track stream end reason (done/timeout/client_gone/scanner_error/eof/panic/ping_fail) and accumulate soft errors during streaming via sync.Once + sync.Mutex. - Add StreamResult (relay/helper) as the callback interface: adapters call sr.Error() for soft errors, sr.Stop() for fatal, sr.Done() for normal completion. No early-return problem — multiple errors per chunk are naturally supported. - Refactor StreamScannerHandler callback from func(string) bool to func(string, *StreamResult). All 9 channel adapters updated. - Write stream_status into log other JSON field (admin-only) with status ok/error, end_reason, error_count, and error messages. - Frontend: display stream status in log detail expansion for admins.
300 lines
7.7 KiB
Go
300 lines
7.7 KiB
Go
package helper
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/QuantumNous/new-api/common"
|
|
"github.com/QuantumNous/new-api/constant"
|
|
"github.com/QuantumNous/new-api/logger"
|
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
|
"github.com/QuantumNous/new-api/setting/operation_setting"
|
|
|
|
"github.com/bytedance/gopkg/util/gopool"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
// Buffer sizing and keep-alive defaults used by StreamScannerHandler.
const (
	// InitialScannerBufferSize is the initial buffer handed to the line
	// scanner: 64 KiB (64*1024).
	InitialScannerBufferSize = 64 << 10
	// DefaultMaxScannerBufferSize is the default upper bound for a single
	// SSE line: 64 MiB (64*1024*1024). It is used when no positive
	// override is configured (see getScannerBufferSize).
	DefaultMaxScannerBufferSize = 64 << 20
	// DefaultPingInterval is the keep-alive ping period used when the
	// configured interval is not positive.
	DefaultPingInterval = 10 * time.Second
)
|
|
|
|
func getScannerBufferSize() int {
|
|
if constant.StreamScannerMaxBufferMB > 0 {
|
|
return constant.StreamScannerMaxBufferMB << 20
|
|
}
|
|
return DefaultMaxScannerBufferSize
|
|
}
|
|
|
|
func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string, sr *StreamResult)) {
|
|
|
|
if resp == nil || dataHandler == nil {
|
|
return
|
|
}
|
|
|
|
// 无条件新建 StreamStatus
|
|
info.StreamStatus = relaycommon.NewStreamStatus()
|
|
|
|
// 确保响应体总是被关闭
|
|
defer func() {
|
|
if resp.Body != nil {
|
|
resp.Body.Close()
|
|
}
|
|
}()
|
|
|
|
streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
|
|
|
|
var (
|
|
stopChan = make(chan bool, 3) // 增加缓冲区避免阻塞
|
|
scanner = bufio.NewScanner(resp.Body)
|
|
ticker = time.NewTicker(streamingTimeout)
|
|
pingTicker *time.Ticker
|
|
writeMutex sync.Mutex // Mutex to protect concurrent writes
|
|
wg sync.WaitGroup // 用于等待所有 goroutine 退出
|
|
)
|
|
|
|
generalSettings := operation_setting.GetGeneralSetting()
|
|
pingEnabled := generalSettings.PingIntervalEnabled && !info.DisablePing
|
|
pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
|
|
if pingInterval <= 0 {
|
|
pingInterval = DefaultPingInterval
|
|
}
|
|
|
|
if pingEnabled {
|
|
pingTicker = time.NewTicker(pingInterval)
|
|
}
|
|
|
|
if common.DebugEnabled {
|
|
// print timeout and ping interval for debugging
|
|
println("relay timeout seconds:", common.RelayTimeout)
|
|
println("relay max idle conns:", common.RelayMaxIdleConns)
|
|
println("relay max idle conns per host:", common.RelayMaxIdleConnsPerHost)
|
|
println("streaming timeout seconds:", int64(streamingTimeout.Seconds()))
|
|
println("ping interval seconds:", int64(pingInterval.Seconds()))
|
|
}
|
|
|
|
// 改进资源清理,确保所有 goroutine 正确退出
|
|
defer func() {
|
|
// 通知所有 goroutine 停止
|
|
common.SafeSendBool(stopChan, true)
|
|
|
|
ticker.Stop()
|
|
if pingTicker != nil {
|
|
pingTicker.Stop()
|
|
}
|
|
|
|
// 等待所有 goroutine 退出,最多等待5秒
|
|
done := make(chan struct{})
|
|
gopool.Go(func() {
|
|
wg.Wait()
|
|
close(done)
|
|
})
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(5 * time.Second):
|
|
logger.LogError(c, "timeout waiting for goroutines to exit")
|
|
}
|
|
|
|
close(stopChan)
|
|
}()
|
|
|
|
scanner.Buffer(make([]byte, InitialScannerBufferSize), getScannerBufferSize())
|
|
scanner.Split(bufio.ScanLines)
|
|
SetEventStreamHeaders(c)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
ctx = context.WithValue(ctx, "stop_chan", stopChan)
|
|
|
|
// Handle ping data sending with improved error handling
|
|
if pingEnabled && pingTicker != nil {
|
|
wg.Add(1)
|
|
gopool.Go(func() {
|
|
defer func() {
|
|
wg.Done()
|
|
if r := recover(); r != nil {
|
|
logger.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r))
|
|
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPanic, fmt.Errorf("ping panic: %v", r))
|
|
common.SafeSendBool(stopChan, true)
|
|
}
|
|
if common.DebugEnabled {
|
|
println("ping goroutine exited")
|
|
}
|
|
}()
|
|
|
|
// 添加超时保护,防止 goroutine 无限运行
|
|
maxPingDuration := 30 * time.Minute // 最大 ping 持续时间
|
|
pingTimeout := time.NewTimer(maxPingDuration)
|
|
defer pingTimeout.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-pingTicker.C:
|
|
// 使用超时机制防止写操作阻塞
|
|
done := make(chan error, 1)
|
|
gopool.Go(func() {
|
|
writeMutex.Lock()
|
|
defer writeMutex.Unlock()
|
|
done <- PingData(c)
|
|
})
|
|
|
|
select {
|
|
case err := <-done:
|
|
if err != nil {
|
|
logger.LogError(c, "ping data error: "+err.Error())
|
|
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPingFail, err)
|
|
return
|
|
}
|
|
if common.DebugEnabled {
|
|
println("ping data sent")
|
|
}
|
|
case <-time.After(10 * time.Second):
|
|
logger.LogError(c, "ping data send timeout")
|
|
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPingFail, fmt.Errorf("ping send timeout"))
|
|
return
|
|
case <-ctx.Done():
|
|
return
|
|
case <-stopChan:
|
|
return
|
|
}
|
|
case <-ctx.Done():
|
|
return
|
|
case <-stopChan:
|
|
return
|
|
case <-c.Request.Context().Done():
|
|
// 监听客户端断开连接
|
|
return
|
|
case <-pingTimeout.C:
|
|
logger.LogError(c, "ping goroutine max duration reached")
|
|
return
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
dataChan := make(chan string, 10)
|
|
|
|
wg.Add(1)
|
|
gopool.Go(func() {
|
|
defer func() {
|
|
wg.Done()
|
|
if r := recover(); r != nil {
|
|
logger.LogError(c, fmt.Sprintf("data handler goroutine panic: %v", r))
|
|
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPanic, fmt.Errorf("handler panic: %v", r))
|
|
}
|
|
common.SafeSendBool(stopChan, true)
|
|
}()
|
|
sr := newStreamResult(info.StreamStatus)
|
|
for data := range dataChan {
|
|
sr.reset()
|
|
writeMutex.Lock()
|
|
dataHandler(data, sr)
|
|
writeMutex.Unlock()
|
|
if sr.IsStopped() {
|
|
return
|
|
}
|
|
}
|
|
})
|
|
|
|
// Scanner goroutine with improved error handling
|
|
wg.Add(1)
|
|
common.RelayCtxGo(ctx, func() {
|
|
defer func() {
|
|
close(dataChan)
|
|
wg.Done()
|
|
if r := recover(); r != nil {
|
|
logger.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r))
|
|
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPanic, fmt.Errorf("scanner panic: %v", r))
|
|
}
|
|
common.SafeSendBool(stopChan, true)
|
|
if common.DebugEnabled {
|
|
println("scanner goroutine exited")
|
|
}
|
|
}()
|
|
|
|
for scanner.Scan() {
|
|
// 检查是否需要停止
|
|
select {
|
|
case <-stopChan:
|
|
return
|
|
case <-ctx.Done():
|
|
return
|
|
case <-c.Request.Context().Done():
|
|
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, c.Request.Context().Err())
|
|
return
|
|
default:
|
|
}
|
|
|
|
ticker.Reset(streamingTimeout)
|
|
data := scanner.Text()
|
|
if common.DebugEnabled {
|
|
println(data)
|
|
}
|
|
|
|
if len(data) < 6 {
|
|
continue
|
|
}
|
|
if data[:5] != "data:" && data[:6] != "[DONE]" {
|
|
continue
|
|
}
|
|
data = data[5:]
|
|
data = strings.TrimSpace(data)
|
|
if data == "" {
|
|
continue
|
|
}
|
|
if !strings.HasPrefix(data, "[DONE]") {
|
|
info.SetFirstResponseTime()
|
|
info.ReceivedResponseCount++
|
|
|
|
select {
|
|
case dataChan <- data:
|
|
case <-ctx.Done():
|
|
return
|
|
case <-stopChan:
|
|
return
|
|
}
|
|
} else {
|
|
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonDone, nil)
|
|
if common.DebugEnabled {
|
|
println("received [DONE], stopping scanner")
|
|
}
|
|
return
|
|
}
|
|
}
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
if err != io.EOF {
|
|
logger.LogError(c, "scanner error: "+err.Error())
|
|
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonScannerErr, err)
|
|
}
|
|
}
|
|
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonEOF, nil)
|
|
})
|
|
|
|
// 主循环等待完成或超时
|
|
select {
|
|
case <-ticker.C:
|
|
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonTimeout, nil)
|
|
case <-stopChan:
|
|
// EndReason already set by the goroutine that triggered stopChan
|
|
case <-c.Request.Context().Done():
|
|
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, c.Request.Context().Err())
|
|
}
|
|
|
|
if info.StreamStatus.IsNormalEnd() && !info.StreamStatus.HasErrors() {
|
|
logger.LogInfo(c, fmt.Sprintf("stream ended: %s", info.StreamStatus.Summary()))
|
|
} else {
|
|
logger.LogError(c, fmt.Sprintf("stream ended: %s, received=%d", info.StreamStatus.Summary(), info.ReceivedResponseCount))
|
|
}
|
|
}
|