refactor(file-source): unify file source creation and enhance caching mechanisms

This commit is contained in:
CaIon
2026-04-06 15:54:42 +08:00
parent 8fc0eb78e2
commit 03758a4a85
10 changed files with 292 additions and 386 deletions
+18 -24
View File
@@ -98,6 +98,20 @@ func (c *ClaudeMediaMessage) ParseMediaContent() []ClaudeMediaMessage {
return mediaContent return mediaContent
} }
// ToFileSource converts this media block's source into a types.FileSource.
//
// The Url field is preferred; when it is empty the Data field is stringified
// with common.Interface2String and used instead. The chosen string and the
// source's MediaType are passed to types.NewFileSourceFromData.
//
// It returns nil when the receiver is nil, when the block has no Source, or
// when neither Url nor Data yields a non-empty string.
func (c *ClaudeMediaMessage) ToFileSource() types.FileSource {
	// Guard the receiver as well as the field so the method is safe to call
	// through a nil *ClaudeMediaMessage.
	if c == nil || c.Source == nil {
		return nil
	}
	data := c.Source.Url
	if data == "" {
		data = common.Interface2String(c.Source.Data)
	}
	if data == "" {
		return nil
	}
	return types.NewFileSourceFromData(data, c.Source.MediaType)
}
type ClaudeMessageSource struct { type ClaudeMessageSource struct {
Type string `json:"type"` Type string `json:"type"`
MediaType string `json:"media_type,omitempty"` MediaType string `json:"media_type,omitempty"`
@@ -223,14 +237,6 @@ type OutputConfigForEffort struct {
Effort string `json:"effort,omitempty"` Effort string `json:"effort,omitempty"`
} }
// createClaudeFileSource picks the FileSource constructor that matches data:
// a string starting with "http://" or "https://" becomes a URL source, and
// anything else is passed to NewBase64FileSource with an empty MIME type.
func createClaudeFileSource(data string) *types.FileSource {
	switch {
	case strings.HasPrefix(data, "http://"), strings.HasPrefix(data, "https://"):
		return types.NewURLFileSource(data)
	default:
		return types.NewBase64FileSource(data, "")
	}
}
func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta { func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
maxTokens := 0 maxTokens := 0
if c.MaxTokens != nil { if c.MaxTokens != nil {
@@ -258,22 +264,16 @@ func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
case "text": case "text":
texts = append(texts, media.GetText()) texts = append(texts, media.GetText())
case "image": case "image":
if media.Source != nil { if source := media.ToFileSource(); source != nil {
data := media.Source.Url
if data == "" {
data = common.Interface2String(media.Source.Data)
}
if data != "" {
fileMeta = append(fileMeta, &types.FileMeta{ fileMeta = append(fileMeta, &types.FileMeta{
FileType: types.FileTypeImage, FileType: types.FileTypeImage,
Source: createClaudeFileSource(data), Source: source,
}) })
} }
} }
} }
} }
} }
}
// messages // messages
for _, message := range c.Messages { for _, message := range c.Messages {
@@ -293,18 +293,12 @@ func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
case "text": case "text":
texts = append(texts, media.GetText()) texts = append(texts, media.GetText())
case "image": case "image":
if media.Source != nil { if source := media.ToFileSource(); source != nil {
data := media.Source.Url
if data == "" {
data = common.Interface2String(media.Source.Data)
}
if data != "" {
fileMeta = append(fileMeta, &types.FileMeta{ fileMeta = append(fileMeta, &types.FileMeta{
FileType: types.FileTypeImage, FileType: types.FileTypeImage,
Source: createClaudeFileSource(data), Source: source,
}) })
} }
}
case "tool_use": case "tool_use":
if media.Name != "" { if media.Name != "" {
texts = append(texts, media.Name) texts = append(texts, media.Name)
+8 -11
View File
@@ -64,14 +64,6 @@ type LatLng struct {
Longitude *float64 `json:"longitude,omitempty"` Longitude *float64 `json:"longitude,omitempty"`
} }
// createGeminiFileSource builds a FileSource for data. When data starts with
// "http://" or "https://" a URL source is returned and mimeType is not used;
// otherwise data is wrapped as a base64 source carrying mimeType.
func createGeminiFileSource(data string, mimeType string) *types.FileSource {
	for _, scheme := range []string{"http://", "https://"} {
		if strings.HasPrefix(data, scheme) {
			return types.NewURLFileSource(data)
		}
	}
	return types.NewBase64FileSource(data, mimeType)
}
func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta { func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
var files []*types.FileMeta = make([]*types.FileMeta, 0) var files []*types.FileMeta = make([]*types.FileMeta, 0)
@@ -87,9 +79,8 @@ func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
if part.Text != "" { if part.Text != "" {
inputTexts = append(inputTexts, part.Text) inputTexts = append(inputTexts, part.Text)
} }
if part.InlineData != nil && part.InlineData.Data != "" { if source := part.InlineData.ToFileSource(); source != nil {
mimeType := part.InlineData.MimeType mimeType := part.InlineData.MimeType
source := createGeminiFileSource(part.InlineData.Data, mimeType)
var fileType types.FileType var fileType types.FileType
if strings.HasPrefix(mimeType, "image/") { if strings.HasPrefix(mimeType, "image/") {
fileType = types.FileTypeImage fileType = types.FileTypeImage
@@ -103,7 +94,6 @@ func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
files = append(files, &types.FileMeta{ files = append(files, &types.FileMeta{
FileType: fileType, FileType: fileType,
Source: source, Source: source,
MimeType: mimeType,
}) })
} }
} }
@@ -215,6 +205,13 @@ type GeminiInlineData struct {
Data string `json:"data"` Data string `json:"data"`
} }
// ToFileSource wraps the inline data in a types.FileSource built from Data and
// MimeType. It returns nil for a nil receiver or when Data is empty.
func (g *GeminiInlineData) ToFileSource() types.FileSource {
	if g != nil && g.Data != "" {
		return types.NewFileSourceFromData(g.Data, g.MimeType)
	}
	return nil
}
// UnmarshalJSON custom unmarshaler for GeminiInlineData to support snake_case and camelCase for MimeType // UnmarshalJSON custom unmarshaler for GeminiInlineData to support snake_case and camelCase for MimeType
func (g *GeminiInlineData) UnmarshalJSON(data []byte) error { func (g *GeminiInlineData) UnmarshalJSON(data []byte) error {
type Alias GeminiInlineData // Use type alias to avoid recursion type Alias GeminiInlineData // Use type alias to avoid recursion
+52 -46
View File
@@ -108,14 +108,6 @@ type GeneralOpenAIRequest struct {
ReasoningSplit json.RawMessage `json:"reasoning_split,omitempty"` ReasoningSplit json.RawMessage `json:"reasoning_split,omitempty"`
} }
// createFileSource returns a URL source when data carries an "http://" or
// "https://" prefix, and a base64 source with an empty MIME type otherwise.
func createFileSource(data string) *types.FileSource {
	isRemote := strings.HasPrefix(data, "http://") || strings.HasPrefix(data, "https://")
	if !isRemote {
		return types.NewBase64FileSource(data, "")
	}
	return types.NewURLFileSource(data)
}
func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta { func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
var tokenCountMeta types.TokenCountMeta var tokenCountMeta types.TokenCountMeta
var texts = make([]string, 0) var texts = make([]string, 0)
@@ -159,44 +151,24 @@ func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
} }
arrayContent := message.ParseContent() arrayContent := message.ParseContent()
for _, m := range arrayContent { for _, m := range arrayContent {
if m.Type == ContentTypeImageURL { source := m.ToFileSource()
imageUrl := m.GetImageMedia() if source != nil {
if imageUrl != nil && imageUrl.Url != "" { meta := &types.FileMeta{Source: source}
source := createFileSource(imageUrl.Url) switch m.Type {
fileMeta = append(fileMeta, &types.FileMeta{ case ContentTypeImageURL:
FileType: types.FileTypeImage, meta.FileType = types.FileTypeImage
Source: source, if img := m.GetImageMedia(); img != nil {
Detail: imageUrl.Detail, meta.Detail = img.Detail
})
} }
} else if m.Type == ContentTypeInputAudio { case ContentTypeInputAudio:
inputAudio := m.GetInputAudio() meta.FileType = types.FileTypeAudio
if inputAudio != nil && inputAudio.Data != "" { case ContentTypeFile:
source := createFileSource(inputAudio.Data) meta.FileType = types.FileTypeFile
fileMeta = append(fileMeta, &types.FileMeta{ case ContentTypeVideoUrl:
FileType: types.FileTypeAudio, meta.FileType = types.FileTypeVideo
Source: source,
})
} }
} else if m.Type == ContentTypeFile { fileMeta = append(fileMeta, meta)
file := m.GetFile() } else if m.Type == ContentTypeText {
if file != nil && file.FileData != "" {
source := createFileSource(file.FileData)
fileMeta = append(fileMeta, &types.FileMeta{
FileType: types.FileTypeFile,
Source: source,
})
}
} else if m.Type == ContentTypeVideoUrl {
videoUrl := m.GetVideoUrl()
if videoUrl != nil && videoUrl.Url != "" {
source := createFileSource(videoUrl.Url)
fileMeta = append(fileMeta, &types.FileMeta{
FileType: types.FileTypeVideo,
Source: source,
})
}
} else {
texts = append(texts, m.Text) texts = append(texts, m.Text)
} }
} }
@@ -391,6 +363,40 @@ func (m *MediaContent) GetVideoUrl() *MessageVideoUrl {
return nil return nil
} }
// ToFileSource extracts the file payload of a media content part and wraps it
// in a types.FileSource.
//
// The payload and MIME hint depend on m.Type:
//   - ContentTypeImageURL:   image Url, with the image's MimeType
//   - ContentTypeInputAudio: audio Data, with "audio/"+Format when Format is set
//   - ContentTypeFile:       file FileData, no MIME hint
//   - ContentTypeVideoUrl:   video Url, no MIME hint
//
// It returns nil for any other type, when the typed accessor returns nil, or
// when the payload string is empty.
func (m *MediaContent) ToFileSource() types.FileSource {
	var data, mimeType string

	switch m.Type {
	case ContentTypeImageURL:
		if img := m.GetImageMedia(); img != nil {
			data, mimeType = img.Url, img.MimeType
		}
	case ContentTypeInputAudio:
		if audio := m.GetInputAudio(); audio != nil {
			data = audio.Data
			if audio.Format != "" {
				mimeType = "audio/" + audio.Format
			}
		}
	case ContentTypeFile:
		if file := m.GetFile(); file != nil {
			data = file.FileData
		}
	case ContentTypeVideoUrl:
		if video := m.GetVideoUrl(); video != nil {
			data = video.Url
		}
	}

	// An empty payload covers every "nothing to load" case above.
	if data == "" {
		return nil
	}
	return types.NewFileSourceFromData(data, mimeType)
}
type MessageImageUrl struct { type MessageImageUrl struct {
Url string `json:"url"` Url string `json:"url"`
Detail string `json:"detail,omitempty"` Detail string `json:"detail,omitempty"`
@@ -865,7 +871,7 @@ func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
if input.ImageUrl != "" { if input.ImageUrl != "" {
fileMeta = append(fileMeta, &types.FileMeta{ fileMeta = append(fileMeta, &types.FileMeta{
FileType: types.FileTypeImage, FileType: types.FileTypeImage,
Source: createFileSource(input.ImageUrl), Source: types.NewFileSourceFromData(input.ImageUrl, ""),
Detail: input.Detail, Detail: input.Detail,
}) })
} }
@@ -873,7 +879,7 @@ func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
if input.FileUrl != "" { if input.FileUrl != "" {
fileMeta = append(fileMeta, &types.FileMeta{ fileMeta = append(fileMeta, &types.FileMeta{
FileType: types.FileTypeFile, FileType: types.FileTypeFile,
Source: createFileSource(input.FileUrl), Source: types.NewFileSourceFromData(input.FileUrl, ""),
}) })
} }
} else { } else {
+16 -84
View File
@@ -1,12 +1,10 @@
package claude package claude
import ( import (
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"path/filepath"
"strings" "strings"
"github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/common"
@@ -46,61 +44,6 @@ func maybeMarkClaudeRefusal(c *gin.Context, stopReason string) {
} }
} }
// createClaudeFileSource turns an OpenAI-style file part into a FileSource.
//
// It returns nil when file is nil or has no FileData. FileData starting with
// "http://" or "https://" becomes a URL source. Otherwise the data is wrapped
// as a base64 source whose MIME type is looked up from the lower-cased
// extension of FileName; a lookup result of "application/octet-stream" is
// discarded and the MIME type is left empty.
func createClaudeFileSource(file *dto.MessageFile) *types.FileSource {
	if file == nil {
		return nil
	}
	payload := file.FileData
	if payload == "" {
		return nil
	}

	switch {
	case strings.HasPrefix(payload, "http://"), strings.HasPrefix(payload, "https://"):
		return types.NewURLFileSource(payload)
	}

	var mimeType string
	ext := strings.ToLower(filepath.Ext(file.FileName))
	ext = strings.TrimPrefix(ext, ".")
	if ext != "" {
		detected := service.GetMimeTypeByExtension(ext)
		if detected != "application/octet-stream" {
			mimeType = detected
		}
	}
	return types.NewBase64FileSource(payload, mimeType)
}
// buildClaudeFileMessage converts an OpenAI-style file part into a Claude
// media message.
//
// The file is loaded through service.GetBase64Data and dispatched on its
// lower-cased MIME type:
//   - "application/pdf": a "document" block carrying the base64 payload
//   - "text/plain":      a "text" block carrying the decoded payload
//   - anything else:     skipped; a message is logged and (nil, nil) returned
//
// (nil, nil) is also returned when the file yields no source. An error is
// returned when loading the data or decoding the text payload fails.
func buildClaudeFileMessage(c *gin.Context, file *dto.MessageFile) (*dto.ClaudeMediaMessage, error) {
	source := createClaudeFileSource(file)
	if source == nil {
		return nil, nil
	}

	base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting document for Claude")
	if err != nil {
		return nil, fmt.Errorf("get file data failed: %w", err)
	}

	normalized := strings.ToLower(mimeType)

	if normalized == "application/pdf" {
		document := &dto.ClaudeMediaMessage{Type: "document"}
		document.Source = &dto.ClaudeMessageSource{
			Type:      "base64",
			MediaType: mimeType,
			Data:      base64Data,
		}
		return document, nil
	}

	if normalized == "text/plain" {
		raw, decodeErr := base64.StdEncoding.DecodeString(base64Data)
		if decodeErr != nil {
			return nil, fmt.Errorf("decode text file data failed: %w", decodeErr)
		}
		text := &dto.ClaudeMediaMessage{
			Type: "text",
			Text: common.GetPointer(string(raw)),
		}
		return text, nil
	}

	// Unsupported MIME type: log through the request logger when a context is
	// available, otherwise through the system log, and skip the part.
	skipMsg := fmt.Sprintf("claude: skip unsupported file content, filename=%q, mime=%q", file.FileName, mimeType)
	if c == nil {
		common.SysLog(skipMsg)
	} else {
		logger.LogInfo(c, skipMsg)
	}
	return nil, nil
}
func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) { func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
claudeTools := make([]any, 0, len(textRequest.Tools)) claudeTools := make([]any, 0, len(textRequest.Tools))
@@ -142,7 +85,7 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
// 解析 UserLocation JSON // 解析 UserLocation JSON
var userLocationMap map[string]interface{} var userLocationMap map[string]interface{}
if err := json.Unmarshal(textRequest.WebSearchOptions.UserLocation, &userLocationMap); err == nil { if err := common.Unmarshal(textRequest.WebSearchOptions.UserLocation, &userLocationMap); err == nil {
// 检查是否有 approximate 字段 // 检查是否有 approximate 字段
if approximateData, ok := userLocationMap["approximate"].(map[string]interface{}); ok { if approximateData, ok := userLocationMap["approximate"].(map[string]interface{}); ok {
if timezone, ok := approximateData["timezone"].(string); ok && timezone != "" { if timezone, ok := approximateData["timezone"].(string); ok && timezone != "" {
@@ -406,44 +349,33 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
Type: "text", Type: "text",
Text: common.GetPointer[string](mediaMessage.Text), Text: common.GetPointer[string](mediaMessage.Text),
}) })
case dto.ContentTypeImageURL: default:
claudeMediaMessage := dto.ClaudeMediaMessage{ source := mediaMessage.ToFileSource()
Type: "image", if source == nil {
Source: &dto.ClaudeMessageSource{
Type: "base64",
},
}
imageUrl := mediaMessage.GetImageMedia()
if imageUrl == nil {
continue continue
} }
// 使用统一的文件服务获取图片数据
var source *types.FileSource
if strings.HasPrefix(imageUrl.Url, "http") {
source = types.NewURLFileSource(imageUrl.Url)
} else {
source = types.NewBase64FileSource(imageUrl.Url, "")
}
base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting image for Claude") base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting image for Claude")
if err != nil { if err != nil {
return nil, fmt.Errorf("get file data failed: %s", err.Error()) return nil, fmt.Errorf("get file data failed: %s", err.Error())
} }
claudeMediaMessage := dto.ClaudeMediaMessage{
Source: &dto.ClaudeMessageSource{
Type: "base64",
},
}
if strings.HasPrefix(mimeType, "application/pdf") {
claudeMediaMessage.Type = "document"
} else {
claudeMediaMessage.Type = "image"
}
claudeMediaMessage.Source.MediaType = mimeType claudeMediaMessage.Source.MediaType = mimeType
claudeMediaMessage.Source.Data = base64Data claudeMediaMessage.Source.Data = base64Data
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage) claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
// FIXME
//case dto.ContentTypeFile:
// claudeFileMessage, err := buildClaudeFileMessage(c, mediaMessage.GetFile())
// if err != nil {
// return nil, err
// }
// if claudeFileMessage != nil {
// claudeMediaMessages = append(claudeMediaMessages, *claudeFileMessage)
// }
default:
continue continue
} }
} }
if message.ToolCalls != nil { if message.ToolCalls != nil {
for _, toolCall := range message.ParseToolCalls() { for _, toolCall := range message.ParseToolCalls() {
inputObj := make(map[string]any) inputObj := make(map[string]any)
+3 -37
View File
@@ -585,14 +585,10 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
Text: part.Text, Text: part.Text,
}) })
} }
} else if part.Type == dto.ContentTypeImageURL {
// 使用统一的文件服务获取图片数据
var source *types.FileSource
imageUrl := part.GetImageMedia().Url
if strings.HasPrefix(imageUrl, "http") {
source = types.NewURLFileSource(imageUrl)
} else { } else {
source = types.NewBase64FileSource(imageUrl, "") source := part.ToFileSource()
if source == nil {
continue
} }
base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting image for Gemini") base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting image for Gemini")
if err != nil { if err != nil {
@@ -604,36 +600,6 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", mimeType, source.GetIdentifier(), getSupportedMimeTypesList()) return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", mimeType, source.GetIdentifier(), getSupportedMimeTypesList())
} }
parts = append(parts, dto.GeminiPart{
InlineData: &dto.GeminiInlineData{
MimeType: mimeType,
Data: base64Data,
},
})
} else if part.Type == dto.ContentTypeFile {
if part.GetFile().FileId != "" {
return nil, fmt.Errorf("only base64 file is supported in gemini")
}
fileSource := types.NewBase64FileSource(part.GetFile().FileData, "")
base64Data, mimeType, err := service.GetBase64Data(c, fileSource, "formatting file for Gemini")
if err != nil {
return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
}
parts = append(parts, dto.GeminiPart{
InlineData: &dto.GeminiInlineData{
MimeType: mimeType,
Data: base64Data,
},
})
} else if part.Type == dto.ContentTypeInputAudio {
if part.GetInputAudio().Data == "" {
return nil, fmt.Errorf("only base64 audio is supported in gemini")
}
audioSource := types.NewBase64FileSource(part.GetInputAudio().Data, "audio/"+part.GetInputAudio().Format)
base64Data, mimeType, err := service.GetBase64Data(c, audioSource, "formatting audio for Gemini")
if err != nil {
return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
}
parts = append(parts, dto.GeminiPart{ parts = append(parts, dto.GeminiPart{
InlineData: &dto.GeminiInlineData{ InlineData: &dto.GeminiInlineData{
MimeType: mimeType, MimeType: mimeType,
+2 -9
View File
@@ -98,15 +98,8 @@ func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*Ollam
parts := m.ParseContent() parts := m.ParseContent()
for _, part := range parts { for _, part := range parts {
if part.Type == dto.ContentTypeImageURL { if part.Type == dto.ContentTypeImageURL {
img := part.GetImageMedia() source := part.ToFileSource()
if img != nil && img.Url != "" { if source != nil {
// 使用统一的文件服务获取图片数据
var source *types.FileSource
if strings.HasPrefix(img.Url, "http") {
source = types.NewURLFileSource(img.Url)
} else {
source = types.NewBase64FileSource(img.Url, "")
}
base64Data, _, err := service.GetBase64Data(c, source, "fetch image for ollama chat") base64Data, _, err := service.GetBase64Data(c, source, "fetch image for ollama chat")
if err != nil { if err != nil {
return nil, err return nil, err
+50 -29
View File
@@ -25,14 +25,26 @@ import (
// FileService 统一的文件处理服务 // FileService 统一的文件处理服务
// 提供文件下载、解码、缓存等功能的统一入口 // 提供文件下载、解码、缓存等功能的统一入口
// getContextCacheKey 生成 context 缓存的 key // getContextCacheKey 生成 URL context 缓存的 key
func getContextCacheKey(url string) string { func getContextCacheKey(url string) string {
return fmt.Sprintf("file_cache_%s", common.GenerateHMAC(url)) return fmt.Sprintf("file_cache_%s", common.GenerateHMAC(url))
} }
// getBase64ContextCacheKey builds the context cache key for an inline base64
// payload.
//
// To avoid hashing the whole payload, the key material is the payload length,
// the MIME type, and fixed-size samples of the data. Payloads of up to three
// samples in length are used in full. Longer payloads contribute a head, a
// middle and a tail sample: sampling only the head would give the same key to
// distinct payloads that share a length and a common header, such as
// same-sized uncompressed images.
//
// The key is still a heuristic fingerprint, not a full content hash.
func getBase64ContextCacheKey(data string, mimeType string) string {
	const sampleLen = 128

	keyMaterial := fmt.Sprintf("%d:%s:", len(data), mimeType)
	if len(data) <= 3*sampleLen {
		keyMaterial += data
	} else {
		mid := (len(data) - sampleLen) / 2
		keyMaterial += data[:sampleLen] + data[mid:mid+sampleLen] + data[len(data)-sampleLen:]
	}
	return fmt.Sprintf("b64_cache_%s", common.GenerateHMAC(keyMaterial))
}
// LoadFileSource 加载文件源数据 // LoadFileSource 加载文件源数据
// 这是统一的入口,会自动处理缓存和不同的来源类型 // 这是统一的入口,会自动处理缓存和不同的来源类型
func LoadFileSource(c *gin.Context, source *types.FileSource, reason ...string) (*types.CachedFileData, error) { func LoadFileSource(c *gin.Context, source types.FileSource, reason ...string) (*types.CachedFileData, error) {
if source == nil { if source == nil {
return nil, fmt.Errorf("file source is nil") return nil, fmt.Errorf("file source is nil")
} }
@@ -43,7 +55,6 @@ func LoadFileSource(c *gin.Context, source *types.FileSource, reason ...string)
// 1. 快速检查内部缓存 // 1. 快速检查内部缓存
if source.HasCache() { if source.HasCache() {
// 即使命中内部缓存,也要确保注册到清理列表(如果尚未注册)
if c != nil { if c != nil {
registerSourceForCleanup(c, source) registerSourceForCleanup(c, source)
} }
@@ -62,39 +73,49 @@ func LoadFileSource(c *gin.Context, source *types.FileSource, reason ...string)
return source.GetCache(), nil return source.GetCache(), nil
} }
// 4. 如果是 URL,检查 Context 缓存 // 4. 根据来源类型加载(含 URL context 缓存查找)
var cachedData *types.CachedFileData
var contextKey string var contextKey string
if source.IsURL() && c != nil { var err error
contextKey = getContextCacheKey(source.URL)
if cachedData, exists := c.Get(contextKey); exists { switch s := source.(type) {
data := cachedData.(*types.CachedFileData) case *types.URLSource:
if c != nil {
contextKey = getContextCacheKey(s.URL)
if cached, exists := c.Get(contextKey); exists {
data := cached.(*types.CachedFileData)
source.SetCache(data) source.SetCache(data)
registerSourceForCleanup(c, source) registerSourceForCleanup(c, source)
return data, nil return data, nil
} }
} }
cachedData, err = loadFromURL(c, s.URL, reason...)
// 5. 执行加载逻辑 case *types.Base64Source:
var cachedData *types.CachedFileData if c != nil {
var err error contextKey = getBase64ContextCacheKey(s.Base64Data, s.MimeType)
if cached, exists := c.Get(contextKey); exists {
if source.IsURL() { data := cached.(*types.CachedFileData)
cachedData, err = loadFromURL(c, source.URL, reason...) source.SetCache(data)
} else { registerSourceForCleanup(c, source)
cachedData, err = loadFromBase64(source.Base64Data, source.MimeType) return data, nil
}
}
cachedData, err = loadFromBase64(s.Base64Data, s.MimeType)
default:
return nil, fmt.Errorf("unsupported file source type: %T", source)
} }
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 6. 设置缓存 // 5. 设置缓存
source.SetCache(cachedData) source.SetCache(cachedData)
if contextKey != "" && c != nil { if contextKey != "" && c != nil {
c.Set(contextKey, cachedData) c.Set(contextKey, cachedData)
} }
// 7. 注册到 context 以便请求结束时自动清理 // 6. 注册到 context 以便请求结束时自动清理
if c != nil { if c != nil {
registerSourceForCleanup(c, source) registerSourceForCleanup(c, source)
} }
@@ -103,15 +124,15 @@ func LoadFileSource(c *gin.Context, source *types.FileSource, reason ...string)
} }
// registerSourceForCleanup 注册 FileSource 到 context 以便请求结束时清理 // registerSourceForCleanup 注册 FileSource 到 context 以便请求结束时清理
func registerSourceForCleanup(c *gin.Context, source *types.FileSource) { func registerSourceForCleanup(c *gin.Context, source types.FileSource) {
if source.IsRegistered() { if source.IsRegistered() {
return return
} }
key := string(constant.ContextKeyFileSourcesToCleanup) key := string(constant.ContextKeyFileSourcesToCleanup)
var sources []*types.FileSource var sources []types.FileSource
if existing, exists := c.Get(key); exists { if existing, exists := c.Get(key); exists {
sources = existing.([]*types.FileSource) sources = existing.([]types.FileSource)
} }
sources = append(sources, source) sources = append(sources, source)
c.Set(key, sources) c.Set(key, sources)
@@ -123,12 +144,12 @@ func registerSourceForCleanup(c *gin.Context, source *types.FileSource) {
func CleanupFileSources(c *gin.Context) { func CleanupFileSources(c *gin.Context) {
key := string(constant.ContextKeyFileSourcesToCleanup) key := string(constant.ContextKeyFileSourcesToCleanup)
if sources, exists := c.Get(key); exists { if sources, exists := c.Get(key); exists {
for _, source := range sources.([]*types.FileSource) { for _, source := range sources.([]types.FileSource) {
if cache := source.GetCache(); cache != nil { if cache := source.GetCache(); cache != nil {
cache.Close() cache.Close()
} }
} }
c.Set(key, nil) // 清除引用 c.Set(key, nil)
} }
} }
@@ -363,7 +384,7 @@ func loadFromBase64(base64String string, providedMimeType string) (*types.Cached
} }
// GetImageConfig 获取图片配置 // GetImageConfig 获取图片配置
func GetImageConfig(c *gin.Context, source *types.FileSource) (image.Config, string, error) { func GetImageConfig(c *gin.Context, source types.FileSource) (image.Config, string, error) {
cachedData, err := LoadFileSource(c, source, "get_image_config") cachedData, err := LoadFileSource(c, source, "get_image_config")
if err != nil { if err != nil {
return image.Config{}, "", err return image.Config{}, "", err
@@ -394,7 +415,7 @@ func GetImageConfig(c *gin.Context, source *types.FileSource) (image.Config, str
} }
// GetBase64Data 获取 base64 编码的数据 // GetBase64Data 获取 base64 编码的数据
func GetBase64Data(c *gin.Context, source *types.FileSource, reason ...string) (string, string, error) { func GetBase64Data(c *gin.Context, source types.FileSource, reason ...string) (string, string, error) {
cachedData, err := LoadFileSource(c, source, reason...) cachedData, err := LoadFileSource(c, source, reason...)
if err != nil { if err != nil {
return "", "", err return "", "", err
@@ -407,13 +428,13 @@ func GetBase64Data(c *gin.Context, source *types.FileSource, reason ...string) (
} }
// GetMimeType 获取文件的 MIME 类型 // GetMimeType 获取文件的 MIME 类型
func GetMimeType(c *gin.Context, source *types.FileSource) (string, error) { func GetMimeType(c *gin.Context, source types.FileSource) (string, error) {
if source.HasCache() { if source.HasCache() {
return source.GetCache().MimeType, nil return source.GetCache().MimeType, nil
} }
if source.IsURL() { if urlSource, ok := source.(*types.URLSource); ok {
mimeType, err := GetFileTypeFromUrl(c, source.URL, "get_mime_type") mimeType, err := GetFileTypeFromUrl(c, urlSource.URL, "get_mime_type")
if err == nil && mimeType != "" && mimeType != "application/octet-stream" { if err == nil && mimeType != "" && mimeType != "application/octet-stream" {
return mimeType, nil return mimeType, nil
} }
-3
View File
@@ -100,8 +100,6 @@ func getImageToken(c *gin.Context, fileMeta *types.FileMeta, model string, strea
if err != nil { if err != nil {
return 0, err return 0, err
} }
fileMeta.MimeType = format
if config.Width == 0 || config.Height == 0 { if config.Width == 0 || config.Height == 0 {
// not an image, but might be a valid file // not an image, but might be a valid file
if format != "" { if format != "" {
@@ -268,7 +266,6 @@ func EstimateRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *rela
} }
continue continue
} }
file.MimeType = cachedData.MimeType
file.FileType = DetectFileType(cachedData.MimeType) file.FileType = DetectFileType(cachedData.MimeType)
} }
} }
+128 -127
View File
@@ -4,39 +4,144 @@ import (
"fmt" "fmt"
"image" "image"
"os" "os"
"strings"
"sync" "sync"
) )
// FileSourceType 文件来源类型 // FileSource 统一的文件来源抽象接口
type FileSourceType string
const (
FileSourceTypeURL FileSourceType = "url" // URL 来源
FileSourceTypeBase64 FileSourceType = "base64" // Base64 内联数据
)
// FileSource 统一的文件来源抽象
// 支持 URL 和 base64 两种来源,提供懒加载和缓存机制 // 支持 URL 和 base64 两种来源,提供懒加载和缓存机制
type FileSource struct { type FileSource interface {
Type FileSourceType `json:"type"` // 来源类型 IsURL() bool
URL string `json:"url,omitempty"` // URL(当 Type 为 url 时) GetIdentifier() string
Base64Data string `json:"base64_data,omitempty"` // Base64 数据(当 Type 为 base64 时) GetRawData() string
MimeType string `json:"mime_type,omitempty"` // MIME 类型(可选,会自动检测) ClearRawData()
// 内部缓存(不导出,不序列化) SetCache(data *CachedFileData)
GetCache() *CachedFileData
HasCache() bool
ClearCache()
IsRegistered() bool
SetRegistered(registered bool)
Mu() *sync.Mutex
}
// baseFileSource 共享的缓存/锁/清理注册状态
type baseFileSource struct {
cachedData *CachedFileData cachedData *CachedFileData
cacheLoaded bool cacheLoaded bool
registered bool // 是否已注册到清理列表 registered bool
mu sync.Mutex // 保护加载过程 mu sync.Mutex
} }
// Mu 获取内部锁 func (b *baseFileSource) SetCache(data *CachedFileData) {
func (f *FileSource) Mu() *sync.Mutex { b.cachedData = data
return &f.mu b.cacheLoaded = true
} }
// CachedFileData 缓存的文件数据 func (b *baseFileSource) GetCache() *CachedFileData {
// 支持内存缓存和磁盘缓存两种模式 return b.cachedData
}
// HasCache reports whether a cache entry has been set and is non-nil.
func (b *baseFileSource) HasCache() bool {
	if !b.cacheLoaded {
		return false
	}
	return b.cachedData != nil
}

// ClearCache closes the cached data, if any, and resets the cache state.
func (b *baseFileSource) ClearCache() {
	if cached := b.cachedData; cached != nil {
		cached.Close()
	}
	b.cachedData, b.cacheLoaded = nil, false
}

// IsRegistered returns the registered flag.
func (b *baseFileSource) IsRegistered() bool {
	return b.registered
}

// SetRegistered stores the registered flag.
func (b *baseFileSource) SetRegistered(registered bool) {
	b.registered = registered
}

// Mu returns a pointer to the source's internal mutex.
func (b *baseFileSource) Mu() *sync.Mutex {
	return &b.mu
}
// ---------------------------------------------------------------------------
// URLSource — FileSource implementation for data referenced by URL
// ---------------------------------------------------------------------------

// URLSource is a file source identified by a URL.
type URLSource struct {
	baseFileSource
	URL string
}

// IsURL always reports true for a URLSource.
func (u *URLSource) IsURL() bool { return true }

// GetIdentifier returns the URL, cut to its first 100 bytes with "..."
// appended when it is longer than that.
func (u *URLSource) GetIdentifier() string {
	const maxLen = 100
	if len(u.URL) <= maxLen {
		return u.URL
	}
	return u.URL[:maxLen] + "..."
}

// GetRawData returns the URL.
func (u *URLSource) GetRawData() string { return u.URL }

// ClearRawData is a no-op for URL sources.
func (u *URLSource) ClearRawData() {}
// ---------------------------------------------------------------------------
// Base64Source — FileSource implementation for inline base64 data
// ---------------------------------------------------------------------------

// Base64Source is a file source carrying an inline base64 payload and its
// MIME type.
type Base64Source struct {
	baseFileSource
	Base64Data string
	MimeType   string
}

// IsURL always reports false for a Base64Source.
func (b *Base64Source) IsURL() bool { return false }

// GetIdentifier returns "base64:" followed by the payload, cut to its first
// 50 bytes with "..." appended when it is longer than that.
func (b *Base64Source) GetIdentifier() string {
	const previewLen = 50
	preview := b.Base64Data
	if len(preview) > previewLen {
		preview = preview[:previewLen] + "..."
	}
	return "base64:" + preview
}

// GetRawData returns the base64 payload.
func (b *Base64Source) GetRawData() string { return b.Base64Data }

// ClearRawData drops the payload when it is longer than 1024 bytes; shorter
// payloads are left in place.
func (b *Base64Source) ClearRawData() {
	if len(b.Base64Data) <= 1024 {
		return
	}
	b.Base64Data = ""
}
// ---------------------------------------------------------------------------
// Constructors
// ---------------------------------------------------------------------------
// NewURLFileSource returns a URLSource for url with no cached data.
func NewURLFileSource(url string) *URLSource {
	src := new(URLSource)
	src.URL = url
	return src
}
// NewBase64FileSource returns a Base64Source holding base64Data and mimeType,
// with no cached data.
func NewBase64FileSource(base64Data string, mimeType string) *Base64Source {
	src := new(Base64Source)
	src.Base64Data = base64Data
	src.MimeType = mimeType
	return src
}
// NewFileSourceFromData builds a FileSource from a raw string.
//
// Data that starts with an "http://" or "https://" scheme becomes a URL
// source. The scheme is matched case-insensitively because URI schemes are
// case-insensitive, so "HTTPS://host/x.png" is treated as a URL rather than
// being misread as base64. Anything else is wrapped as a base64 source tagged
// with mimeType. mimeType is not used for URL sources.
func NewFileSourceFromData(data string, mimeType string) FileSource {
	if hasPrefixFold(data, "http://") || hasPrefixFold(data, "https://") {
		return NewURLFileSource(data)
	}
	return NewBase64FileSource(data, mimeType)
}

// hasPrefixFold reports whether s begins with prefix, ignoring letter case.
func hasPrefixFold(s, prefix string) bool {
	return len(s) >= len(prefix) && strings.EqualFold(s[:len(prefix)], prefix)
}
// ---------------------------------------------------------------------------
// CachedFileData — 缓存的文件数据(支持内存和磁盘两种模式)
// ---------------------------------------------------------------------------
type CachedFileData struct { type CachedFileData struct {
base64Data string // 内存中的 base64 数据(小文件) base64Data string // 内存中的 base64 数据(小文件)
MimeType string // MIME 类型 MimeType string // MIME 类型
@@ -45,18 +150,15 @@ type CachedFileData struct {
ImageConfig *image.Config // 图片配置(如果是图片) ImageConfig *image.Config // 图片配置(如果是图片)
ImageFormat string // 图片格式(如果是图片) ImageFormat string // 图片格式(如果是图片)
// 磁盘缓存相关
diskPath string // 磁盘缓存文件路径(大文件) diskPath string // 磁盘缓存文件路径(大文件)
isDisk bool // 是否使用磁盘缓存 isDisk bool // 是否使用磁盘缓存
diskMu sync.Mutex // 磁盘操作锁(保护磁盘文件的读取和删除) diskMu sync.Mutex // 磁盘操作锁(保护磁盘文件的读取和删除)
diskClosed bool // 是否已关闭/清理 diskClosed bool // 是否已关闭/清理
statDecremented bool // 是否已扣减统计 statDecremented bool // 是否已扣减统计
// 统计回调,避免循环依赖
OnClose func(size int64) OnClose func(size int64)
} }
// NewMemoryCachedData 创建内存缓存的数据
func NewMemoryCachedData(base64Data string, mimeType string, size int64) *CachedFileData { func NewMemoryCachedData(base64Data string, mimeType string, size int64) *CachedFileData {
return &CachedFileData{ return &CachedFileData{
base64Data: base64Data, base64Data: base64Data,
@@ -66,7 +168,6 @@ func NewMemoryCachedData(base64Data string, mimeType string, size int64) *Cached
} }
} }
// NewDiskCachedData 创建磁盘缓存的数据
func NewDiskCachedData(diskPath string, mimeType string, size int64) *CachedFileData { func NewDiskCachedData(diskPath string, mimeType string, size int64) *CachedFileData {
return &CachedFileData{ return &CachedFileData{
diskPath: diskPath, diskPath: diskPath,
@@ -76,7 +177,6 @@ func NewDiskCachedData(diskPath string, mimeType string, size int64) *CachedFile
} }
} }
// GetBase64Data 获取 base64 数据(自动处理内存/磁盘)
func (c *CachedFileData) GetBase64Data() (string, error) { func (c *CachedFileData) GetBase64Data() (string, error) {
if !c.isDisk { if !c.isDisk {
return c.base64Data, nil return c.base64Data, nil
@@ -89,7 +189,6 @@ func (c *CachedFileData) GetBase64Data() (string, error) {
return "", fmt.Errorf("disk cache already closed") return "", fmt.Errorf("disk cache already closed")
} }
// 从磁盘读取
data, err := os.ReadFile(c.diskPath) data, err := os.ReadFile(c.diskPath)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to read from disk cache: %w", err) return "", fmt.Errorf("failed to read from disk cache: %w", err)
@@ -97,22 +196,19 @@ func (c *CachedFileData) GetBase64Data() (string, error) {
return string(data), nil return string(data), nil
} }
// SetBase64Data 设置 base64 数据(仅用于内存模式)
func (c *CachedFileData) SetBase64Data(data string) { func (c *CachedFileData) SetBase64Data(data string) {
if !c.isDisk { if !c.isDisk {
c.base64Data = data c.base64Data = data
} }
} }
// IsDisk 是否使用磁盘缓存
func (c *CachedFileData) IsDisk() bool { func (c *CachedFileData) IsDisk() bool {
return c.isDisk return c.isDisk
} }
// Close 关闭并清理资源
func (c *CachedFileData) Close() error { func (c *CachedFileData) Close() error {
if !c.isDisk { if !c.isDisk {
c.base64Data = "" // 释放内存 c.base64Data = ""
return nil return nil
} }
@@ -126,7 +222,6 @@ func (c *CachedFileData) Close() error {
c.diskClosed = true c.diskClosed = true
if c.diskPath != "" { if c.diskPath != "" {
err := os.Remove(c.diskPath) err := os.Remove(c.diskPath)
// 只有在删除成功且未扣减过统计时,才执行回调
if err == nil && !c.statDecremented && c.OnClose != nil { if err == nil && !c.statDecremented && c.OnClose != nil {
c.OnClose(c.DiskSize) c.OnClose(c.DiskSize)
c.statDecremented = true c.statDecremented = true
@@ -135,97 +230,3 @@ func (c *CachedFileData) Close() error {
} }
return nil return nil
} }
// NewURLFileSource builds a FileSource that refers to its content by URL.
//
// Only Type (set to FileSourceTypeURL) and URL are populated; every other
// field keeps its zero value.
func NewURLFileSource(url string) *FileSource {
	src := new(FileSource)
	src.Type = FileSourceTypeURL
	src.URL = url
	return src
}
// NewBase64FileSource builds a FileSource that carries its content inline as
// base64 text.
//
// Type is set to FileSourceTypeBase64, and base64Data and mimeType are stored
// as given (no validation or decoding happens here).
func NewBase64FileSource(base64Data string, mimeType string) *FileSource {
	var src FileSource
	src.Type = FileSourceTypeBase64
	src.Base64Data = base64Data
	src.MimeType = mimeType
	return &src
}
// IsURL reports whether this source's Type is FileSourceTypeURL.
func (f *FileSource) IsURL() bool {
	switch f.Type {
	case FileSourceTypeURL:
		return true
	default:
		return false
	}
}
// IsBase64 reports whether this source's Type is FileSourceTypeBase64.
func (f *FileSource) IsBase64() bool {
	if f.Type != FileSourceTypeBase64 {
		return false
	}
	return true
}
// GetIdentifier returns a short label for this source, meant for logs and
// error tracing.
//
// For a URL source the label is the URL itself, shortened to at most 100
// bytes followed by "..." when it is longer. For any other source the label
// is "base64:" followed by the base64 payload, shortened to at most 50 bytes
// followed by "..." when it is longer.
//
// Shortening never splits a multi-byte UTF-8 sequence: if the byte limit
// falls inside one, the cut moves back to the start of that sequence, so the
// label stays valid UTF-8 for log sinks and JSON encoders. Pure ASCII input
// is cut exactly at the limit.
func (f *FileSource) GetIdentifier() string {
	if f.IsURL() {
		return truncateAtRuneBoundary(f.URL, 100)
	}
	return "base64:" + truncateAtRuneBoundary(f.Base64Data, 50)
}

// truncateAtRuneBoundary returns s unchanged when it is at most limit bytes
// long; otherwise it returns a prefix of at most limit bytes, ending on a
// UTF-8 rune boundary, followed by "...".
func truncateAtRuneBoundary(s string, limit int) string {
	if len(s) <= limit {
		return s
	}
	cut := limit
	// A byte of the form 10xxxxxx is a UTF-8 continuation byte. If s[cut] is
	// one, the sequence it belongs to started before cut, so step back to its
	// lead byte. A sequence is at most 4 bytes, hence at most 3 steps; the cap
	// keeps malformed input from shrinking the prefix any further.
	for cut > 0 && limit-cut < 3 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut] + "..."
}
// GetRawData returns the untruncated payload of this source: the URL when
// IsURL reports true, otherwise the full base64 string.
func (f *FileSource) GetRawData() string {
	raw := f.Base64Data
	if f.IsURL() {
		raw = f.URL
	}
	return raw
}
// SetCache attaches data as this source's cached content and marks the cache
// as loaded.
//
// The loaded flag is set unconditionally, including when data is nil.
func (f *FileSource) SetCache(data *CachedFileData) {
	f.cachedData, f.cacheLoaded = data, true
}
// IsRegistered returns the current value of the registered flag.
//
// NOTE(review): per the original comment the flag records whether this source
// has been added to a cleanup list; that list is not visible here, so confirm
// against the caller.
func (f *FileSource) IsRegistered() bool {
	registered := f.registered
	return registered
}
// SetRegistered overwrites the registered flag with the given value.
func (f *FileSource) SetRegistered(registered bool) {
	flag := &f.registered
	*flag = registered
}
// GetCache returns the cached content attached to this source, or nil when
// none is attached.
func (f *FileSource) GetCache() *CachedFileData {
	cached := f.cachedData
	return cached
}
// HasCache reports whether usable cached content is attached: the cache must
// be marked loaded and the cached data must be non-nil.
func (f *FileSource) HasCache() bool {
	if !f.cacheLoaded {
		return false
	}
	return f.cachedData != nil
}
// ClearCache detaches any cached content from this source and resets the
// loaded flag.
//
// When cached content is attached, its Close method is called first so the
// entry can release whatever it holds.
func (f *FileSource) ClearCache() {
	if cached := f.cachedData; cached != nil {
		// Best effort: this method has no error return, so a failure reported
		// by Close is deliberately dropped.
		_ = cached.Close()
	}
	f.cacheLoaded = false
	f.cachedData = nil
}
// ClearRawData drops a large inline payload so its memory can be reclaimed
// once processing is finished.
//
// Only base64 sources are affected, and only when the payload is longer than
// 1024 bytes. The URL field is never touched, and shorter base64 payloads are
// kept as they are.
func (f *FileSource) ClearRawData() {
	const keepUpTo = 1024 // payloads up to this many bytes are retained

	if !f.IsBase64() {
		return
	}
	if len(f.Base64Data) <= keepUpTo {
		return
	}
	f.Base64Data = ""
}
+3 -4
View File
@@ -32,13 +32,12 @@ type TokenCountMeta struct {
type FileMeta struct { type FileMeta struct {
FileType FileType
MimeType string Source FileSource // 统一的文件来源(URL 或 base64)
Source *FileSource // 统一的文件来源(URL 或 base64)
Detail string // 图片细节级别(low/high/auto Detail string // 图片细节级别(low/high/auto
} }
// NewFileMeta 创建新的 FileMeta // NewFileMeta 创建新的 FileMeta
func NewFileMeta(fileType FileType, source *FileSource) *FileMeta { func NewFileMeta(fileType FileType, source FileSource) *FileMeta {
return &FileMeta{ return &FileMeta{
FileType: fileType, FileType: fileType,
Source: source, Source: source,
@@ -46,7 +45,7 @@ func NewFileMeta(fileType FileType, source *FileSource) *FileMeta {
} }
// NewImageFileMeta 创建图片类型的 FileMeta // NewImageFileMeta 创建图片类型的 FileMeta
func NewImageFileMeta(source *FileSource, detail string) *FileMeta { func NewImageFileMeta(source FileSource, detail string) *FileMeta {
return &FileMeta{ return &FileMeta{
FileType: FileTypeImage, FileType: FileTypeImage,
Source: source, Source: source,