refactor(file-source): unify file source creation and enhance caching mechanisms
This commit is contained in:
@@ -1,12 +1,10 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
@@ -46,61 +44,6 @@ func maybeMarkClaudeRefusal(c *gin.Context, stopReason string) {
|
||||
}
|
||||
}
|
||||
|
||||
func createClaudeFileSource(file *dto.MessageFile) *types.FileSource {
|
||||
if file == nil || file.FileData == "" {
|
||||
return nil
|
||||
}
|
||||
if strings.HasPrefix(file.FileData, "http://") || strings.HasPrefix(file.FileData, "https://") {
|
||||
return types.NewURLFileSource(file.FileData)
|
||||
}
|
||||
mimeType := ""
|
||||
if ext := strings.TrimPrefix(strings.ToLower(filepath.Ext(file.FileName)), "."); ext != "" {
|
||||
if detected := service.GetMimeTypeByExtension(ext); detected != "application/octet-stream" {
|
||||
mimeType = detected
|
||||
}
|
||||
}
|
||||
return types.NewBase64FileSource(file.FileData, mimeType)
|
||||
}
|
||||
|
||||
func buildClaudeFileMessage(c *gin.Context, file *dto.MessageFile) (*dto.ClaudeMediaMessage, error) {
|
||||
source := createClaudeFileSource(file)
|
||||
if source == nil {
|
||||
return nil, nil
|
||||
}
|
||||
base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting document for Claude")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get file data failed: %w", err)
|
||||
}
|
||||
switch strings.ToLower(mimeType) {
|
||||
case "application/pdf":
|
||||
return &dto.ClaudeMediaMessage{
|
||||
Type: "document",
|
||||
Source: &dto.ClaudeMessageSource{
|
||||
Type: "base64",
|
||||
MediaType: mimeType,
|
||||
Data: base64Data,
|
||||
},
|
||||
}, nil
|
||||
case "text/plain":
|
||||
decodedData, err := base64.StdEncoding.DecodeString(base64Data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode text file data failed: %w", err)
|
||||
}
|
||||
return &dto.ClaudeMediaMessage{
|
||||
Type: "text",
|
||||
Text: common.GetPointer(string(decodedData)),
|
||||
}, nil
|
||||
default:
|
||||
msg := fmt.Sprintf("claude: skip unsupported file content, filename=%q, mime=%q", file.FileName, mimeType)
|
||||
if c != nil {
|
||||
logger.LogInfo(c, msg)
|
||||
} else {
|
||||
common.SysLog(msg)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
|
||||
claudeTools := make([]any, 0, len(textRequest.Tools))
|
||||
|
||||
@@ -142,7 +85,7 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
|
||||
|
||||
// 解析 UserLocation JSON
|
||||
var userLocationMap map[string]interface{}
|
||||
if err := json.Unmarshal(textRequest.WebSearchOptions.UserLocation, &userLocationMap); err == nil {
|
||||
if err := common.Unmarshal(textRequest.WebSearchOptions.UserLocation, &userLocationMap); err == nil {
|
||||
// 检查是否有 approximate 字段
|
||||
if approximateData, ok := userLocationMap["approximate"].(map[string]interface{}); ok {
|
||||
if timezone, ok := approximateData["timezone"].(string); ok && timezone != "" {
|
||||
@@ -406,44 +349,33 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
|
||||
Type: "text",
|
||||
Text: common.GetPointer[string](mediaMessage.Text),
|
||||
})
|
||||
case dto.ContentTypeImageURL:
|
||||
claudeMediaMessage := dto.ClaudeMediaMessage{
|
||||
Type: "image",
|
||||
Source: &dto.ClaudeMessageSource{
|
||||
Type: "base64",
|
||||
},
|
||||
}
|
||||
imageUrl := mediaMessage.GetImageMedia()
|
||||
if imageUrl == nil {
|
||||
default:
|
||||
source := mediaMessage.ToFileSource()
|
||||
if source == nil {
|
||||
continue
|
||||
}
|
||||
// 使用统一的文件服务获取图片数据
|
||||
var source *types.FileSource
|
||||
if strings.HasPrefix(imageUrl.Url, "http") {
|
||||
source = types.NewURLFileSource(imageUrl.Url)
|
||||
} else {
|
||||
source = types.NewBase64FileSource(imageUrl.Url, "")
|
||||
}
|
||||
base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting image for Claude")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get file data failed: %s", err.Error())
|
||||
}
|
||||
claudeMediaMessage := dto.ClaudeMediaMessage{
|
||||
Source: &dto.ClaudeMessageSource{
|
||||
Type: "base64",
|
||||
},
|
||||
}
|
||||
if strings.HasPrefix(mimeType, "application/pdf") {
|
||||
claudeMediaMessage.Type = "document"
|
||||
} else {
|
||||
claudeMediaMessage.Type = "image"
|
||||
}
|
||||
|
||||
claudeMediaMessage.Source.MediaType = mimeType
|
||||
claudeMediaMessage.Source.Data = base64Data
|
||||
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
|
||||
// FIXME
|
||||
//case dto.ContentTypeFile:
|
||||
// claudeFileMessage, err := buildClaudeFileMessage(c, mediaMessage.GetFile())
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// if claudeFileMessage != nil {
|
||||
// claudeMediaMessages = append(claudeMediaMessages, *claudeFileMessage)
|
||||
// }
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if message.ToolCalls != nil {
|
||||
for _, toolCall := range message.ParseToolCalls() {
|
||||
inputObj := make(map[string]any)
|
||||
|
||||
@@ -585,14 +585,10 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
|
||||
Text: part.Text,
|
||||
})
|
||||
}
|
||||
} else if part.Type == dto.ContentTypeImageURL {
|
||||
// 使用统一的文件服务获取图片数据
|
||||
var source *types.FileSource
|
||||
imageUrl := part.GetImageMedia().Url
|
||||
if strings.HasPrefix(imageUrl, "http") {
|
||||
source = types.NewURLFileSource(imageUrl)
|
||||
} else {
|
||||
source = types.NewBase64FileSource(imageUrl, "")
|
||||
} else {
|
||||
source := part.ToFileSource()
|
||||
if source == nil {
|
||||
continue
|
||||
}
|
||||
base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting image for Gemini")
|
||||
if err != nil {
|
||||
@@ -604,36 +600,6 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
|
||||
return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", mimeType, source.GetIdentifier(), getSupportedMimeTypesList())
|
||||
}
|
||||
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
InlineData: &dto.GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
Data: base64Data,
|
||||
},
|
||||
})
|
||||
} else if part.Type == dto.ContentTypeFile {
|
||||
if part.GetFile().FileId != "" {
|
||||
return nil, fmt.Errorf("only base64 file is supported in gemini")
|
||||
}
|
||||
fileSource := types.NewBase64FileSource(part.GetFile().FileData, "")
|
||||
base64Data, mimeType, err := service.GetBase64Data(c, fileSource, "formatting file for Gemini")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
InlineData: &dto.GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
Data: base64Data,
|
||||
},
|
||||
})
|
||||
} else if part.Type == dto.ContentTypeInputAudio {
|
||||
if part.GetInputAudio().Data == "" {
|
||||
return nil, fmt.Errorf("only base64 audio is supported in gemini")
|
||||
}
|
||||
audioSource := types.NewBase64FileSource(part.GetInputAudio().Data, "audio/"+part.GetInputAudio().Format)
|
||||
base64Data, mimeType, err := service.GetBase64Data(c, audioSource, "formatting audio for Gemini")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
InlineData: &dto.GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
|
||||
@@ -98,15 +98,8 @@ func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*Ollam
|
||||
parts := m.ParseContent()
|
||||
for _, part := range parts {
|
||||
if part.Type == dto.ContentTypeImageURL {
|
||||
img := part.GetImageMedia()
|
||||
if img != nil && img.Url != "" {
|
||||
// 使用统一的文件服务获取图片数据
|
||||
var source *types.FileSource
|
||||
if strings.HasPrefix(img.Url, "http") {
|
||||
source = types.NewURLFileSource(img.Url)
|
||||
} else {
|
||||
source = types.NewBase64FileSource(img.Url, "")
|
||||
}
|
||||
source := part.ToFileSource()
|
||||
if source != nil {
|
||||
base64Data, _, err := service.GetBase64Data(c, source, "fetch image for ollama chat")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
Reference in New Issue
Block a user