diff --git a/dto/claude.go b/dto/claude.go index 73bfa9c5..d292f97e 100644 --- a/dto/claude.go +++ b/dto/claude.go @@ -98,6 +98,20 @@ func (c *ClaudeMediaMessage) ParseMediaContent() []ClaudeMediaMessage { return mediaContent } +func (m *ClaudeMediaMessage) ToFileSource() types.FileSource { + if m.Source == nil { + return nil + } + data := m.Source.Url + if data == "" { + data = common.Interface2String(m.Source.Data) + } + if data == "" { + return nil + } + return types.NewFileSourceFromData(data, m.Source.MediaType) +} + type ClaudeMessageSource struct { Type string `json:"type"` MediaType string `json:"media_type,omitempty"` @@ -223,14 +237,6 @@ type OutputConfigForEffort struct { Effort string `json:"effort,omitempty"` } -// createClaudeFileSource 根据数据内容创建正确类型的 FileSource -func createClaudeFileSource(data string) *types.FileSource { - if strings.HasPrefix(data, "http://") || strings.HasPrefix(data, "https://") { - return types.NewURLFileSource(data) - } - return types.NewBase64FileSource(data, "") -} - func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta { maxTokens := 0 if c.MaxTokens != nil { @@ -258,17 +264,11 @@ func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta { case "text": texts = append(texts, media.GetText()) case "image": - if media.Source != nil { - data := media.Source.Url - if data == "" { - data = common.Interface2String(media.Source.Data) - } - if data != "" { - fileMeta = append(fileMeta, &types.FileMeta{ - FileType: types.FileTypeImage, - Source: createClaudeFileSource(data), - }) - } + if source := media.ToFileSource(); source != nil { + fileMeta = append(fileMeta, &types.FileMeta{ + FileType: types.FileTypeImage, + Source: source, + }) } } } @@ -293,17 +293,11 @@ func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta { case "text": texts = append(texts, media.GetText()) case "image": - if media.Source != nil { - data := media.Source.Url - if data == "" { - data = common.Interface2String(media.Source.Data) - } - if data != "" { - fileMeta = append(fileMeta, &types.FileMeta{ - FileType: types.FileTypeImage, - Source: createClaudeFileSource(data), - }) - } + if source := media.ToFileSource(); source != nil { + fileMeta = append(fileMeta, &types.FileMeta{ + FileType: types.FileTypeImage, + Source: source, + }) } case "tool_use": if media.Name != "" { diff --git a/dto/gemini.go b/dto/gemini.go index 63c99b86..029c3f03 100644 --- a/dto/gemini.go +++ b/dto/gemini.go @@ -64,14 +64,6 @@ type LatLng struct { Longitude *float64 `json:"longitude,omitempty"` } -// createGeminiFileSource 根据数据内容创建正确类型的 FileSource -func createGeminiFileSource(data string, mimeType string) *types.FileSource { - if strings.HasPrefix(data, "http://") || strings.HasPrefix(data, "https://") { - return types.NewURLFileSource(data) - } - return types.NewBase64FileSource(data, mimeType) -} - func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta { var files []*types.FileMeta = make([]*types.FileMeta, 0) @@ -87,9 +79,8 @@ func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta { if part.Text != "" { inputTexts = append(inputTexts, part.Text) } - if part.InlineData != nil && part.InlineData.Data != "" { + if source := part.InlineData.ToFileSource(); source != nil { mimeType := part.InlineData.MimeType - source := createGeminiFileSource(part.InlineData.Data, mimeType) var fileType types.FileType if strings.HasPrefix(mimeType, "image/") { fileType = types.FileTypeImage @@ -103,7 +94,6 @@ func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta { files = append(files, &types.FileMeta{ FileType: fileType, Source: source, - MimeType: mimeType, }) } } @@ -215,6 +205,13 @@ type GeminiInlineData struct { Data string `json:"data"` } +func (d *GeminiInlineData) ToFileSource() types.FileSource { + if d == nil || d.Data == "" { + return nil + } + return types.NewFileSourceFromData(d.Data, d.MimeType) +} + // UnmarshalJSON custom unmarshaler for GeminiInlineData to support snake_case and camelCase for MimeType func (g *GeminiInlineData) UnmarshalJSON(data []byte) error { type Alias GeminiInlineData // Use type alias to avoid recursion diff --git a/dto/openai_request.go b/dto/openai_request.go index 76a86662..25ef3a21 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -108,14 +108,6 @@ type GeneralOpenAIRequest struct { ReasoningSplit json.RawMessage `json:"reasoning_split,omitempty"` } -// createFileSource 根据数据内容创建正确类型的 FileSource -func createFileSource(data string) *types.FileSource { - if strings.HasPrefix(data, "http://") || strings.HasPrefix(data, "https://") { - return types.NewURLFileSource(data) - } - return types.NewBase64FileSource(data, "") -} - func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta { var tokenCountMeta types.TokenCountMeta var texts = make([]string, 0) @@ -159,44 +151,24 @@ func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta { } arrayContent := message.ParseContent() for _, m := range arrayContent { - if m.Type == ContentTypeImageURL { - imageUrl := m.GetImageMedia() - if imageUrl != nil && imageUrl.Url != "" { - source := createFileSource(imageUrl.Url) - fileMeta = append(fileMeta, &types.FileMeta{ - FileType: types.FileTypeImage, - Source: source, - Detail: imageUrl.Detail, - }) + source := m.ToFileSource() + if source != nil { + meta := &types.FileMeta{Source: source} + switch m.Type { + case ContentTypeImageURL: + meta.FileType = types.FileTypeImage + if img := m.GetImageMedia(); img != nil { + meta.Detail = img.Detail + } + case ContentTypeInputAudio: + meta.FileType = types.FileTypeAudio + case ContentTypeFile: + meta.FileType = types.FileTypeFile + case ContentTypeVideoUrl: + meta.FileType = types.FileTypeVideo } - } else if m.Type == ContentTypeInputAudio { - inputAudio := m.GetInputAudio() - if inputAudio != nil && inputAudio.Data != "" { - source := createFileSource(inputAudio.Data) - fileMeta = append(fileMeta, &types.FileMeta{ - FileType: types.FileTypeAudio, - Source: source, - }) - } - } else if m.Type == ContentTypeFile { - file := m.GetFile() - if file != nil && file.FileData != "" { - source := createFileSource(file.FileData) - fileMeta = append(fileMeta, &types.FileMeta{ - FileType: types.FileTypeFile, - Source: source, - }) - } - } else if m.Type == ContentTypeVideoUrl { - videoUrl := m.GetVideoUrl() - if videoUrl != nil && videoUrl.Url != "" { - source := createFileSource(videoUrl.Url) - fileMeta = append(fileMeta, &types.FileMeta{ - FileType: types.FileTypeVideo, - Source: source, - }) - } - } else { + fileMeta = append(fileMeta, meta) + } else if m.Type == ContentTypeText { texts = append(texts, m.Text) } } @@ -391,6 +363,40 @@ func (m *MediaContent) GetVideoUrl() *MessageVideoUrl { return nil } +func (m *MediaContent) ToFileSource() types.FileSource { + switch m.Type { + case ContentTypeImageURL: + img := m.GetImageMedia() + if img == nil || img.Url == "" { + return nil + } + return types.NewFileSourceFromData(img.Url, img.MimeType) + case ContentTypeInputAudio: + audio := m.GetInputAudio() + if audio == nil || audio.Data == "" { + return nil + } + mimeType := "" + if audio.Format != "" { + mimeType = "audio/" + audio.Format + } + return types.NewFileSourceFromData(audio.Data, mimeType) + case ContentTypeFile: + file := m.GetFile() + if file == nil || file.FileData == "" { + return nil + } + return types.NewFileSourceFromData(file.FileData, "") + case ContentTypeVideoUrl: + video := m.GetVideoUrl() + if video == nil || video.Url == "" { + return nil + } + return types.NewFileSourceFromData(video.Url, "") + } + return nil +} + type MessageImageUrl struct { Url string `json:"url"` Detail string `json:"detail,omitempty"` @@ -865,7 +871,7 @@ func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta { if input.ImageUrl != "" { fileMeta = append(fileMeta, &types.FileMeta{ FileType: types.FileTypeImage, - Source: createFileSource(input.ImageUrl), + Source: types.NewFileSourceFromData(input.ImageUrl, ""), Detail: input.Detail, }) } @@ -873,7 +879,7 @@ func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta { if input.FileUrl != "" { fileMeta = append(fileMeta, &types.FileMeta{ FileType: types.FileTypeFile, - Source: createFileSource(input.FileUrl), + Source: types.NewFileSourceFromData(input.FileUrl, ""), }) } } else { diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index ba97bc90..dceff5e7 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -1,12 +1,10 @@ package claude import ( - "encoding/base64" "encoding/json" "fmt" "io" "net/http" - "path/filepath" "strings" "github.com/QuantumNous/new-api/common" @@ -46,61 +44,6 @@ func maybeMarkClaudeRefusal(c *gin.Context, stopReason string) { } } -func createClaudeFileSource(file *dto.MessageFile) *types.FileSource { - if file == nil || file.FileData == "" { - return nil - } - if strings.HasPrefix(file.FileData, "http://") || strings.HasPrefix(file.FileData, "https://") { - return types.NewURLFileSource(file.FileData) - } - mimeType := "" - if ext := strings.TrimPrefix(strings.ToLower(filepath.Ext(file.FileName)), "."); ext != "" { - if detected := service.GetMimeTypeByExtension(ext); detected != "application/octet-stream" { - mimeType = detected - } - } - return types.NewBase64FileSource(file.FileData, mimeType) -} - -func buildClaudeFileMessage(c *gin.Context, file *dto.MessageFile) (*dto.ClaudeMediaMessage, error) { - source := createClaudeFileSource(file) - if source == nil { - return nil, nil - } - base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting document for Claude") - if err != nil { - return nil, fmt.Errorf("get file data failed: %w", err) - } - switch strings.ToLower(mimeType) { - case "application/pdf": - return &dto.ClaudeMediaMessage{ - Type: "document", - Source: &dto.ClaudeMessageSource{ - Type: "base64", - MediaType: mimeType, - Data: base64Data, - }, - }, nil - case "text/plain": - decodedData, err := base64.StdEncoding.DecodeString(base64Data) - if err != nil { - return nil, fmt.Errorf("decode text file data failed: %w", err) - } - return &dto.ClaudeMediaMessage{ - Type: "text", - Text: common.GetPointer(string(decodedData)), - }, nil - default: - msg := fmt.Sprintf("claude: skip unsupported file content, filename=%q, mime=%q", file.FileName, mimeType) - if c != nil { - logger.LogInfo(c, msg) - } else { - common.SysLog(msg) - } - return nil, nil - } -} - func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) { claudeTools := make([]any, 0, len(textRequest.Tools)) @@ -142,7 +85,7 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe // 解析 UserLocation JSON var userLocationMap map[string]interface{} - if err := json.Unmarshal(textRequest.WebSearchOptions.UserLocation, &userLocationMap); err == nil { + if err := common.Unmarshal(textRequest.WebSearchOptions.UserLocation, &userLocationMap); err == nil { // 检查是否有 approximate 字段 if approximateData, ok := userLocationMap["approximate"].(map[string]interface{}); ok { if timezone, ok := approximateData["timezone"].(string); ok && timezone != "" { @@ -406,44 +349,33 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe Type: "text", Text: common.GetPointer[string](mediaMessage.Text), }) - case dto.ContentTypeImageURL: - claudeMediaMessage := dto.ClaudeMediaMessage{ - Type: "image", - Source: &dto.ClaudeMessageSource{ - Type: "base64", - }, - } - imageUrl := mediaMessage.GetImageMedia() - if imageUrl == nil { + default: + source := mediaMessage.ToFileSource() + if source == nil { continue } - // 使用统一的文件服务获取图片数据 - var source *types.FileSource - if strings.HasPrefix(imageUrl.Url, "http") { - source = types.NewURLFileSource(imageUrl.Url) - } else { - source = types.NewBase64FileSource(imageUrl.Url, "") - } base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting image for Claude") if err != nil { return nil, fmt.Errorf("get file data failed: %s", err.Error()) } + claudeMediaMessage := dto.ClaudeMediaMessage{ + Source: &dto.ClaudeMessageSource{ + Type: "base64", + }, + } + if strings.HasPrefix(mimeType, "application/pdf") { + claudeMediaMessage.Type = "document" + } else { + claudeMediaMessage.Type = "image" + } + claudeMediaMessage.Source.MediaType = mimeType claudeMediaMessage.Source.Data = base64Data claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage) - // FIXME - //case dto.ContentTypeFile: - // claudeFileMessage, err := buildClaudeFileMessage(c, mediaMessage.GetFile()) - // if err != nil { - // return nil, err - // } - // if claudeFileMessage != nil { - // claudeMediaMessages = append(claudeMediaMessages, *claudeFileMessage) - // } - default: continue } } + if message.ToolCalls != nil { for _, toolCall := range message.ParseToolCalls() { inputObj := make(map[string]any) diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 2f1e7ecb..69175e76 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -585,14 +585,10 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i Text: part.Text, }) } - } else if part.Type == dto.ContentTypeImageURL { - // 使用统一的文件服务获取图片数据 - var source *types.FileSource - imageUrl := part.GetImageMedia().Url - if strings.HasPrefix(imageUrl, "http") { - source = types.NewURLFileSource(imageUrl) - } else { - source = types.NewBase64FileSource(imageUrl, "") + } else { + source := part.ToFileSource() + if source == nil { + continue } base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting image for Gemini") if err != nil { @@ -604,36 +600,6 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", mimeType, source.GetIdentifier(), getSupportedMimeTypesList()) } - parts = append(parts, dto.GeminiPart{ - InlineData: &dto.GeminiInlineData{ - MimeType: mimeType, - Data: base64Data, - }, - }) - } else if part.Type == dto.ContentTypeFile { - if part.GetFile().FileId != "" { - return nil, fmt.Errorf("only base64 file is supported in gemini") - } - fileSource := types.NewBase64FileSource(part.GetFile().FileData, "") - base64Data, mimeType, err := service.GetBase64Data(c, fileSource, "formatting file for Gemini") - if err != nil { - return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error()) - } - parts = append(parts, dto.GeminiPart{ - InlineData: &dto.GeminiInlineData{ - MimeType: mimeType, - Data: base64Data, - }, - }) - } else if part.Type == dto.ContentTypeInputAudio { - if part.GetInputAudio().Data == "" { - return nil, fmt.Errorf("only base64 audio is supported in gemini") - } - audioSource := types.NewBase64FileSource(part.GetInputAudio().Data, "audio/"+part.GetInputAudio().Format) - base64Data, mimeType, err := service.GetBase64Data(c, audioSource, "formatting audio for Gemini") - if err != nil { - return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error()) - } parts = append(parts, dto.GeminiPart{ InlineData: &dto.GeminiInlineData{ MimeType: mimeType, diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go index afc27160..975c244c 100644 --- a/relay/channel/ollama/relay-ollama.go +++ b/relay/channel/ollama/relay-ollama.go @@ -98,15 +98,8 @@ func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*Ollam parts := m.ParseContent() for _, part := range parts { if part.Type == dto.ContentTypeImageURL { - img := part.GetImageMedia() - if img != nil && img.Url != "" { - // 使用统一的文件服务获取图片数据 - var source *types.FileSource - if strings.HasPrefix(img.Url, "http") { - source = types.NewURLFileSource(img.Url) - } else { - source = types.NewBase64FileSource(img.Url, "") - } + source := part.ToFileSource() + if source != nil { base64Data, _, err := service.GetBase64Data(c, source, "fetch image for ollama chat") if err != nil { return nil, err diff --git a/service/file_service.go b/service/file_service.go index 918e426d..bcf47442 100644 --- a/service/file_service.go +++ b/service/file_service.go @@ -25,14 +25,26 @@ import ( // FileService 统一的文件处理服务 // 提供文件下载、解码、缓存等功能的统一入口 -// getContextCacheKey 生成 context 缓存的 key +// getContextCacheKey 生成 URL context 缓存的 key func getContextCacheKey(url string) string { return fmt.Sprintf("file_cache_%s", common.GenerateHMAC(url)) } +// getBase64ContextCacheKey 生成 base64 context 缓存的 key +// 使用 length + MIME + 前 128 字符作为输入,避免对整个 base64 数据做 hash +func getBase64ContextCacheKey(data string, mimeType string) string { + keyMaterial := fmt.Sprintf("%d:%s:", len(data), mimeType) + if len(data) > 128 { + keyMaterial += data[:128] + } else { + keyMaterial += data + } + return fmt.Sprintf("b64_cache_%s", common.GenerateHMAC(keyMaterial)) +} + // LoadFileSource 加载文件源数据 // 这是统一的入口,会自动处理缓存和不同的来源类型 -func LoadFileSource(c *gin.Context, source *types.FileSource, reason ...string) (*types.CachedFileData, error) { +func LoadFileSource(c *gin.Context, source types.FileSource, reason ...string) (*types.CachedFileData, error) { if source == nil { return nil, fmt.Errorf("file source is nil") } @@ -43,7 +55,6 @@ func LoadFileSource(c *gin.Context, source *types.FileSource, reason ...string) // 1. 快速检查内部缓存 if source.HasCache() { - // 即使命中内部缓存,也要确保注册到清理列表(如果尚未注册) if c != nil { registerSourceForCleanup(c, source) } @@ -62,39 +73,49 @@ func LoadFileSource(c *gin.Context, source *types.FileSource, reason ...string) return source.GetCache(), nil } - // 4. 如果是 URL,检查 Context 缓存 - var contextKey string - if source.IsURL() && c != nil { - contextKey = getContextCacheKey(source.URL) - if cachedData, exists := c.Get(contextKey); exists { - data := cachedData.(*types.CachedFileData) - source.SetCache(data) - registerSourceForCleanup(c, source) - return data, nil - } - } - - // 5. 执行加载逻辑 + // 4. 根据来源类型加载(含 URL context 缓存查找) var cachedData *types.CachedFileData + var contextKey string var err error - if source.IsURL() { - cachedData, err = loadFromURL(c, source.URL, reason...) - } else { - cachedData, err = loadFromBase64(source.Base64Data, source.MimeType) + switch s := source.(type) { + case *types.URLSource: + if c != nil { + contextKey = getContextCacheKey(s.URL) + if cached, exists := c.Get(contextKey); exists { + data := cached.(*types.CachedFileData) + source.SetCache(data) + registerSourceForCleanup(c, source) + return data, nil + } + } + cachedData, err = loadFromURL(c, s.URL, reason...) + case *types.Base64Source: + if c != nil { + contextKey = getBase64ContextCacheKey(s.Base64Data, s.MimeType) + if cached, exists := c.Get(contextKey); exists { + data := cached.(*types.CachedFileData) + source.SetCache(data) + registerSourceForCleanup(c, source) + return data, nil + } + } + cachedData, err = loadFromBase64(s.Base64Data, s.MimeType) + default: + return nil, fmt.Errorf("unsupported file source type: %T", source) } if err != nil { return nil, err } - // 6. 设置缓存 + // 5. 设置缓存 source.SetCache(cachedData) if contextKey != "" && c != nil { c.Set(contextKey, cachedData) } - // 7. 注册到 context 以便请求结束时自动清理 + // 6. 注册到 context 以便请求结束时自动清理 if c != nil { registerSourceForCleanup(c, source) } @@ -103,15 +124,15 @@ func LoadFileSource(c *gin.Context, source *types.FileSource, reason ...string) } // registerSourceForCleanup 注册 FileSource 到 context 以便请求结束时清理 -func registerSourceForCleanup(c *gin.Context, source *types.FileSource) { +func registerSourceForCleanup(c *gin.Context, source types.FileSource) { if source.IsRegistered() { return } key := string(constant.ContextKeyFileSourcesToCleanup) - var sources []*types.FileSource + var sources []types.FileSource if existing, exists := c.Get(key); exists { - sources = existing.([]*types.FileSource) + sources = existing.([]types.FileSource) } sources = append(sources, source) c.Set(key, sources) @@ -123,12 +144,12 @@ func registerSourceForCleanup(c *gin.Context, source *types.FileSource) { func CleanupFileSources(c *gin.Context) { key := string(constant.ContextKeyFileSourcesToCleanup) if sources, exists := c.Get(key); exists { - for _, source := range sources.([]*types.FileSource) { + for _, source := range sources.([]types.FileSource) { if cache := source.GetCache(); cache != nil { cache.Close() } } - c.Set(key, nil) // 清除引用 + c.Set(key, nil) } } @@ -363,7 +384,7 @@ func loadFromBase64(base64String string, providedMimeType string) (*types.Cached } // GetImageConfig 获取图片配置 -func GetImageConfig(c *gin.Context, source *types.FileSource) (image.Config, string, error) { +func GetImageConfig(c *gin.Context, source types.FileSource) (image.Config, string, error) { cachedData, err := LoadFileSource(c, source, "get_image_config") if err != nil { return image.Config{}, "", err @@ -394,7 +415,7 @@ func GetImageConfig(c *gin.Context, source *types.FileSource) (image.Config, str } // GetBase64Data 获取 base64 编码的数据 -func GetBase64Data(c *gin.Context, source *types.FileSource, reason ...string) (string, string, error) { +func GetBase64Data(c *gin.Context, source types.FileSource, reason ...string) (string, string, error) { cachedData, err := LoadFileSource(c, source, reason...) if err != nil { return "", "", err @@ -407,13 +428,13 @@ func GetBase64Data(c *gin.Context, source *types.FileSource, reason ...string) ( } // GetMimeType 获取文件的 MIME 类型 -func GetMimeType(c *gin.Context, source *types.FileSource) (string, error) { +func GetMimeType(c *gin.Context, source types.FileSource) (string, error) { if source.HasCache() { return source.GetCache().MimeType, nil } - if source.IsURL() { - mimeType, err := GetFileTypeFromUrl(c, source.URL, "get_mime_type") + if urlSource, ok := source.(*types.URLSource); ok { + mimeType, err := GetFileTypeFromUrl(c, urlSource.URL, "get_mime_type") if err == nil && mimeType != "" && mimeType != "application/octet-stream" { return mimeType, nil } diff --git a/service/token_counter.go b/service/token_counter.go index 7d648d77..63b76d97 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -100,8 +100,6 @@ func getImageToken(c *gin.Context, fileMeta *types.FileMeta, model string, strea if err != nil { return 0, err } - fileMeta.MimeType = format - if config.Width == 0 || config.Height == 0 { // not an image, but might be a valid file if format != "" { @@ -268,7 +266,6 @@ func EstimateRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *rela } continue } - file.MimeType = cachedData.MimeType file.FileType = DetectFileType(cachedData.MimeType) } } diff --git a/types/file_source.go b/types/file_source.go index c52062d7..86ef25e1 100644 --- a/types/file_source.go +++ b/types/file_source.go @@ -4,39 +4,144 @@ import ( "fmt" "image" "os" + "strings" "sync" ) -// FileSourceType 文件来源类型 -type FileSourceType string - -const ( - FileSourceTypeURL FileSourceType = "url" // URL 来源 - FileSourceTypeBase64 FileSourceType = "base64" // Base64 内联数据 -) - -// FileSource 统一的文件来源抽象 +// FileSource 统一的文件来源抽象接口 // 支持 URL 和 base64 两种来源,提供懒加载和缓存机制 -type FileSource struct { - Type FileSourceType `json:"type"` // 来源类型 - URL string `json:"url,omitempty"` // URL(当 Type 为 url 时) - Base64Data string `json:"base64_data,omitempty"` // Base64 数据(当 Type 为 base64 时) - MimeType string `json:"mime_type,omitempty"` // MIME 类型(可选,会自动检测) +type FileSource interface { + IsURL() bool + GetIdentifier() string + GetRawData() string + ClearRawData() - // 内部缓存(不导出,不序列化) + SetCache(data *CachedFileData) + GetCache() *CachedFileData + HasCache() bool + ClearCache() + + IsRegistered() bool + SetRegistered(registered bool) + Mu() *sync.Mutex +} + +// baseFileSource 共享的缓存/锁/清理注册状态 +type baseFileSource struct { cachedData *CachedFileData cacheLoaded bool - registered bool // 是否已注册到清理列表 - mu sync.Mutex // 保护加载过程 + registered bool + mu sync.Mutex } -// Mu 获取内部锁 -func (f *FileSource) Mu() *sync.Mutex { - return &f.mu +func (b *baseFileSource) SetCache(data *CachedFileData) { + b.cachedData = data + b.cacheLoaded = true } -// CachedFileData 缓存的文件数据 -// 支持内存缓存和磁盘缓存两种模式 +func (b *baseFileSource) GetCache() *CachedFileData { + return b.cachedData +} + +func (b *baseFileSource) HasCache() bool { + return b.cacheLoaded && b.cachedData != nil +} + +func (b *baseFileSource) ClearCache() { + if b.cachedData != nil { + b.cachedData.Close() + } + b.cachedData = nil + b.cacheLoaded = false +} + +func (b *baseFileSource) IsRegistered() bool { + return b.registered +} + +func (b *baseFileSource) SetRegistered(registered bool) { + b.registered = registered +} + +func (b *baseFileSource) Mu() *sync.Mutex { + return &b.mu +} + +// --------------------------------------------------------------------------- +// URLSource — URL 来源的 FileSource 实现 +// --------------------------------------------------------------------------- + +type URLSource struct { + baseFileSource + URL string +} + +func (u *URLSource) IsURL() bool { return true } + +func (u *URLSource) GetIdentifier() string { + if len(u.URL) > 100 { + return u.URL[:100] + "..." + } + return u.URL +} + +func (u *URLSource) GetRawData() string { return u.URL } + +func (u *URLSource) ClearRawData() {} + +// --------------------------------------------------------------------------- +// Base64Source — Base64 内联数据来源的 FileSource 实现 +// --------------------------------------------------------------------------- + +type Base64Source struct { + baseFileSource + Base64Data string + MimeType string +} + +func (b *Base64Source) IsURL() bool { return false } + +func (b *Base64Source) GetIdentifier() string { + if len(b.Base64Data) > 50 { + return "base64:" + b.Base64Data[:50] + "..." + } + return "base64:" + b.Base64Data +} + +func (b *Base64Source) GetRawData() string { return b.Base64Data } + +func (b *Base64Source) ClearRawData() { + if len(b.Base64Data) > 1024 { + b.Base64Data = "" + } +} + +// --------------------------------------------------------------------------- +// Constructors +// --------------------------------------------------------------------------- + +func NewURLFileSource(url string) *URLSource { + return &URLSource{URL: url} +} + +func NewBase64FileSource(base64Data string, mimeType string) *Base64Source { + return &Base64Source{ + Base64Data: base64Data, + MimeType: mimeType, + } +} + +func NewFileSourceFromData(data string, mimeType string) FileSource { + if strings.HasPrefix(data, "http://") || strings.HasPrefix(data, "https://") { + return NewURLFileSource(data) + } + return NewBase64FileSource(data, mimeType) +} + +// --------------------------------------------------------------------------- +// CachedFileData — 缓存的文件数据(支持内存和磁盘两种模式) +// --------------------------------------------------------------------------- + type CachedFileData struct { base64Data string // 内存中的 base64 数据(小文件) MimeType string // MIME 类型 @@ -45,18 +150,15 @@ type CachedFileData struct { ImageConfig *image.Config // 图片配置(如果是图片) ImageFormat string // 图片格式(如果是图片) - // 磁盘缓存相关 diskPath string // 磁盘缓存文件路径(大文件) isDisk bool // 是否使用磁盘缓存 diskMu sync.Mutex // 磁盘操作锁(保护磁盘文件的读取和删除) diskClosed bool // 是否已关闭/清理 statDecremented bool // 是否已扣减统计 - // 统计回调,避免循环依赖 OnClose func(size int64) } -// NewMemoryCachedData 创建内存缓存的数据 func NewMemoryCachedData(base64Data string, mimeType string, size int64) *CachedFileData { return &CachedFileData{ base64Data: base64Data, @@ -66,7 +168,6 @@ func NewMemoryCachedData(base64Data string, mimeType string, size int64) *Cached } } -// NewDiskCachedData 创建磁盘缓存的数据 func NewDiskCachedData(diskPath string, mimeType string, size int64) *CachedFileData { return &CachedFileData{ diskPath: diskPath, @@ -76,7 +177,6 @@ func NewDiskCachedData(diskPath string, mimeType string, size int64) *CachedFile } } -// GetBase64Data 获取 base64 数据(自动处理内存/磁盘) func (c *CachedFileData) GetBase64Data() (string, error) { if !c.isDisk { return c.base64Data, nil @@ -89,7 +189,6 @@ func (c *CachedFileData) GetBase64Data() (string, error) { return "", fmt.Errorf("disk cache already closed") } - // 从磁盘读取 data, err := os.ReadFile(c.diskPath) if err != nil { return "", fmt.Errorf("failed to read from disk cache: %w", err) @@ -97,22 +196,19 @@ func (c *CachedFileData) GetBase64Data() (string, error) { return string(data), nil } -// SetBase64Data 设置 base64 数据(仅用于内存模式) func (c *CachedFileData) SetBase64Data(data string) { if !c.isDisk { c.base64Data = data } } -// IsDisk 是否使用磁盘缓存 func (c *CachedFileData) IsDisk() bool { return c.isDisk } -// Close 关闭并清理资源 func (c *CachedFileData) Close() error { if !c.isDisk { - c.base64Data = "" // 释放内存 + c.base64Data = "" return nil } @@ -126,7 +222,6 @@ func (c *CachedFileData) Close() error { c.diskClosed = true if c.diskPath != "" { err := os.Remove(c.diskPath) - // 只有在删除成功且未扣减过统计时,才执行回调 if err == nil && !c.statDecremented && c.OnClose != nil { c.OnClose(c.DiskSize) c.statDecremented = true @@ -135,97 +230,3 @@ func (c *CachedFileData) Close() error { } return nil } - -// NewURLFileSource 创建 URL 来源的 FileSource -func NewURLFileSource(url string) *FileSource { - return &FileSource{ - Type: FileSourceTypeURL, - URL: url, - } -} - -// NewBase64FileSource 创建 base64 来源的 FileSource -func NewBase64FileSource(base64Data string, mimeType string) *FileSource { - return &FileSource{ - Type: FileSourceTypeBase64, - Base64Data: base64Data, - MimeType: mimeType, - } -} - -// IsURL 判断是否是 URL 来源 -func (f *FileSource) IsURL() bool { - return f.Type == FileSourceTypeURL -} - -// IsBase64 判断是否是 base64 来源 -func (f *FileSource) IsBase64() bool { - return f.Type == FileSourceTypeBase64 -} - -// GetIdentifier 获取文件标识符(用于日志和错误追踪) -func (f *FileSource) GetIdentifier() string { - if f.IsURL() { - if len(f.URL) > 100 { - return f.URL[:100] + "..." - } - return f.URL - } - if len(f.Base64Data) > 50 { - return "base64:" + f.Base64Data[:50] + "..." - } - return "base64:" + f.Base64Data -} - -// GetRawData 获取原始数据(URL 或完整的 base64 字符串) -func (f *FileSource) GetRawData() string { - if f.IsURL() { - return f.URL - } - return f.Base64Data -} - -// SetCache 设置缓存数据 -func (f *FileSource) SetCache(data *CachedFileData) { - f.cachedData = data - f.cacheLoaded = true -} - -// IsRegistered 是否已注册到清理列表 -func (f *FileSource) IsRegistered() bool { - return f.registered -} - -// SetRegistered 设置注册状态 -func (f *FileSource) SetRegistered(registered bool) { - f.registered = registered -} - -// GetCache 获取缓存数据 -func (f *FileSource) GetCache() *CachedFileData { - return f.cachedData -} - -// HasCache 是否有缓存 -func (f *FileSource) HasCache() bool { - return f.cacheLoaded && f.cachedData != nil -} - -// ClearCache 清除缓存,释放内存和磁盘文件 -func (f *FileSource) ClearCache() { - // 如果有缓存数据,先关闭它(会清理磁盘文件) - if f.cachedData != nil { - f.cachedData.Close() - } - f.cachedData = nil - f.cacheLoaded = false -} - -// ClearRawData 清除原始数据,只保留必要的元信息 -// 用于在处理完成后释放大文件的内存 -func (f *FileSource) ClearRawData() { - // 保留 URL(通常很短),只清除大的 base64 数据 - if f.IsBase64() && len(f.Base64Data) > 1024 { - f.Base64Data = "" - } -} diff --git a/types/request_meta.go b/types/request_meta.go index 2d909d0b..476ea052 100644 --- a/types/request_meta.go +++ b/types/request_meta.go @@ -32,13 +32,12 @@ type TokenCountMeta struct { type FileMeta struct { FileType - MimeType string - Source *FileSource // 统一的文件来源(URL 或 base64) - Detail string // 图片细节级别(low/high/auto) + Source FileSource // 统一的文件来源(URL 或 base64) + Detail string // 图片细节级别(low/high/auto) } // NewFileMeta 创建新的 FileMeta -func NewFileMeta(fileType FileType, source *FileSource) *FileMeta { +func NewFileMeta(fileType FileType, source FileSource) *FileMeta { return &FileMeta{ FileType: fileType, Source: source, @@ -46,7 +45,7 @@ func NewFileMeta(fileType FileType, source *FileSource) *FileMeta { } // NewImageFileMeta 创建图片类型的 FileMeta -func NewImageFileMeta(source *FileSource, detail string) *FileMeta { +func NewImageFileMeta(source FileSource, detail string) *FileMeta { return &FileMeta{ FileType: FileTypeImage, Source: source,