Files
new-api/service/file_service.go

608 lines
17 KiB
Go

package service
import (
"bytes"
"encoding/base64"
"encoding/binary"
"fmt"
"image"
_ "image/gif"
_ "image/jpeg"
_ "image/png"
"io"
"net/http"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
"golang.org/x/image/webp"
)
// FileService 统一的文件处理服务
// 提供文件下载、解码、缓存等功能的统一入口
// getContextCacheKey builds the gin-context key under which the data loaded
// for a URL source is stored: the fixed "file_cache_" prefix followed by
// common.GenerateHMAC of the URL.
func getContextCacheKey(url string) string {
	digest := common.GenerateHMAC(url)
	return fmt.Sprintf("file_cache_%s", digest)
}
// getBase64ContextCacheKey builds the gin-context key under which the data
// decoded from a base64 source is stored.
//
// The key material is "<len>:<mime>:<fnv1a>:<first 128 chars>", passed through
// common.GenerateHMAC and prefixed with "b64_cache_". Only the short key
// material goes through the HMAC, but a 64-bit FNV-1a digest of the whole
// payload is part of it: length + MIME + the leading 128 characters alone are
// identical for different payloads that share a header and a size (for
// example equal-duration WAV clips that start with silence, or uncompressed
// bitmaps of the same dimensions), and a collision would hand one file's data
// to another source in the same request.
func getBase64ContextCacheKey(data string, mimeType string) string {
	// FNV-1a (64-bit) parameters; computed inline over the string so the
	// payload is never copied into a []byte.
	const (
		fnvOffset64 = 14695981039346656037
		fnvPrime64  = 1099511628211
	)
	digest := uint64(fnvOffset64)
	for i := 0; i < len(data); i++ {
		digest ^= uint64(data[i])
		digest *= fnvPrime64
	}

	head := data
	if len(head) > 128 {
		head = head[:128]
	}

	keyMaterial := fmt.Sprintf("%d:%s:%016x:%s", len(data), mimeType, digest, head)
	return fmt.Sprintf("b64_cache_%s", common.GenerateHMAC(keyMaterial))
}
// LoadFileSource returns the file data behind source, loading it on first use.
// It is the single entry point for both URL and base64 sources.
//
// Lookup order:
//  1. the cache already attached to source (checked before and again after
//     taking source.Mu());
//  2. when c is non-nil, the gin context, keyed by getContextCacheKey or
//     getBase64ContextCacheKey;
//  3. an actual load through loadFromURL or loadFromBase64.
//
// Whenever data is returned and c is non-nil, source is registered on c via
// registerSourceForCleanup. reason is forwarded to loadFromURL. A nil source
// or a source of any other concrete type yields an error.
func LoadFileSource(c *gin.Context, source types.FileSource, reason ...string) (*types.CachedFileData, error) {
	if source == nil {
		return nil, fmt.Errorf("file source is nil")
	}
	if common.DebugEnabled {
		logger.LogDebug(c, fmt.Sprintf("LoadFileSource starting for: %s", source.GetIdentifier()))
	}

	// fromSource answers with the cache already attached to source.
	fromSource := func() (*types.CachedFileData, error) {
		if c != nil {
			registerSourceForCleanup(c, source)
		}
		return source.GetCache(), nil
	}
	// fromContext adopts an entry stored on the gin context under key.
	// It is only invoked when c is non-nil.
	fromContext := func(key string) (*types.CachedFileData, bool) {
		hit, found := c.Get(key)
		if !found {
			return nil, false
		}
		data := hit.(*types.CachedFileData)
		source.SetCache(data)
		registerSourceForCleanup(c, source)
		return data, true
	}

	// Unlocked fast path.
	if source.HasCache() {
		return fromSource()
	}

	// Hold the source's mutex for the load and look at the cache once more
	// now that the lock is held.
	source.Mu().Lock()
	defer source.Mu().Unlock()
	if source.HasCache() {
		return fromSource()
	}

	var (
		loaded     *types.CachedFileData
		contextKey string
		err        error
	)
	switch s := source.(type) {
	case *types.URLSource:
		if c != nil {
			contextKey = getContextCacheKey(s.URL)
			if data, ok := fromContext(contextKey); ok {
				return data, nil
			}
		}
		loaded, err = loadFromURL(c, s.URL, reason...)
	case *types.Base64Source:
		if c != nil {
			contextKey = getBase64ContextCacheKey(s.Base64Data, s.MimeType)
			if data, ok := fromContext(contextKey); ok {
				return data, nil
			}
		}
		loaded, err = loadFromBase64(s.Base64Data, s.MimeType)
	default:
		return nil, fmt.Errorf("unsupported file source type: %T", source)
	}
	if err != nil {
		return nil, err
	}

	// Publish the result on the source and, when there is a context, on the
	// context too, then register the source for cleanup.
	source.SetCache(loaded)
	if c != nil {
		if contextKey != "" {
			c.Set(contextKey, loaded)
		}
		registerSourceForCleanup(c, source)
	}
	return loaded, nil
}
// registerSourceForCleanup appends source to the list of file sources kept on
// the gin context under constant.ContextKeyFileSourcesToCleanup and marks the
// source as registered. A source that already reports IsRegistered is left
// alone, so each source appears in the list at most once.
//
// c must be non-nil.
func registerSourceForCleanup(c *gin.Context, source types.FileSource) {
	if source.IsRegistered() {
		return
	}
	key := string(constant.ContextKeyFileSourcesToCleanup)

	var sources []types.FileSource
	if existing, exists := c.Get(key); exists {
		// Checked assertion: CleanupFileSources stores nil under this key,
		// and gin then reports the key as present with a nil value. An
		// unchecked assertion would panic on the first registration after a
		// cleanup; here the list simply starts over.
		sources, _ = existing.([]types.FileSource)
	}

	sources = append(sources, source)
	c.Set(key, sources)
	source.SetRegistered(true)
}
// CleanupFileSources closes the cached data of every file source registered
// on the gin context and then clears the registration list. It is meant to be
// called when the request ends (typically from middleware).
//
// The function is safe to call more than once and with a nil context: a
// previous call leaves nil under the key, which gin reports as an existing
// entry, so the stored value is type-checked instead of asserted blindly.
func CleanupFileSources(c *gin.Context) {
	if c == nil {
		return
	}
	key := string(constant.ContextKeyFileSourcesToCleanup)
	stored, exists := c.Get(key)
	if !exists {
		return
	}
	if sources, ok := stored.([]types.FileSource); ok {
		for _, source := range sources {
			if cache := source.GetCache(); cache != nil {
				cache.Close()
			}
		}
	}
	c.Set(key, nil)
}
// loadFromURL downloads url and wraps the payload as cached file data.
//
// The response body is read through a limit of constant.MaxFileDownloadMB
// megabytes plus one byte, so a larger response is detected and rejected
// without reading it in full. The payload is base64-encoded and its MIME type
// is resolved by smartDetectMimeType. When shouldUseDiskCache approves the
// base64 size the text is written to the disk cache (falling back to memory,
// with a warning, if the write fails); otherwise it is kept in memory. For
// image/* payloads the dimensions and format are attached when
// decodeImageConfig succeeds.
//
// reason is forwarded unchanged to DoDownloadRequest. Errors are returned for
// a failed request, a non-200 status, a read failure and an oversized body.
func loadFromURL(c *gin.Context, url string, reason ...string) (*types.CachedFileData, error) {
	maxFileSize := constant.MaxFileDownloadMB * 1024 * 1024

	if common.DebugEnabled {
		logger.LogDebug(c, "loadFromURL: initiating download")
	}
	resp, err := DoDownloadRequest(url, reason...)
	if err != nil {
		return nil, fmt.Errorf("failed to download file from %s: %w", url, err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("failed to download file, status code: %d", resp.StatusCode)
	}

	if common.DebugEnabled {
		logger.LogDebug(c, "loadFromURL: reading response body")
	}
	// Read one byte past the limit so "exactly at the limit" and "over the
	// limit" can be told apart.
	fileBytes, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxFileSize+1)))
	if err != nil {
		return nil, fmt.Errorf("failed to read file content: %w", err)
	}
	if len(fileBytes) > maxFileSize {
		return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB)
	}

	base64Data := base64.StdEncoding.EncodeToString(fileBytes)
	mimeType := smartDetectMimeType(resp, url, fileBytes)

	rawSize := int64(len(fileBytes))
	base64Size := int64(len(base64Data))
	var cachedData *types.CachedFileData
	if shouldUseDiskCache(base64Size) {
		diskPath, err := writeToDiskCache(base64Data)
		if err != nil {
			// Disk cache is best effort: keep the data in memory instead.
			logger.LogWarn(c, fmt.Sprintf("Failed to write to disk cache, falling back to memory: %v", err))
			cachedData = types.NewMemoryCachedData(base64Data, mimeType, rawSize)
		} else {
			cachedData = types.NewDiskCachedData(diskPath, mimeType, rawSize)
			cachedData.DiskSize = base64Size
			// Pair the increment below with a decrement when the data is closed.
			cachedData.OnClose = func(size int64) {
				common.DecrementDiskFiles(size)
			}
			common.IncrementDiskFiles(base64Size)
			if common.DebugEnabled {
				logger.LogDebug(c, fmt.Sprintf("File cached to disk: %s, size: %d bytes", diskPath, base64Size))
			}
		}
	} else {
		cachedData = types.NewMemoryCachedData(base64Data, mimeType, rawSize)
	}

	// Attach dimensions/format for images. The MIME type is not revisited
	// here: this branch only runs for an image/* type, and
	// smartDetectMimeType has already tried image decoding as its last resort
	// for untyped payloads.
	if strings.HasPrefix(mimeType, "image/") {
		if common.DebugEnabled {
			logger.LogDebug(c, "loadFromURL: decoding image config")
		}
		if config, format, err := decodeImageConfig(fileBytes); err == nil {
			cachedData.ImageConfig = &config
			cachedData.ImageFormat = format
		}
	}
	return cachedData, nil
}
// shouldUseDiskCache reports whether a payload of dataSize bytes should go to
// the disk cache; the decision is delegated to common.ShouldUseDiskCache.
func shouldUseDiskCache(dataSize int64) bool {
	useDisk := common.ShouldUseDiskCache(dataSize)
	return useDisk
}
// writeToDiskCache stores base64Data through common.WriteDiskCacheFileString
// using the common.DiskCacheTypeFile cache type and passes its result through
// unchanged (callers in this file treat the string as the cache file's path).
func writeToDiskCache(base64Data string) (string, error) {
	diskPath, err := common.WriteDiskCacheFileString(common.DiskCacheTypeFile, base64Data)
	return diskPath, err
}
// smartDetectMimeType works out the MIME type of a downloaded file, trying
// progressively weaker hints and returning the first specific answer:
//
//  1. the Content-Type header (parameters such as charset removed);
//  2. the extension of the Content-Disposition filename;
//  3. the extension of the URL path;
//  4. content sniffing with http.DetectContentType, then the HEIF/HEIC
//     signature check (which the standard library sniffer does not cover);
//  5. decoding the bytes as an image.
//
// "application/octet-stream" counts as "unknown" at every step and is also
// the final fallback.
func smartDetectMimeType(resp *http.Response, url string, fileBytes []byte) string {
	const octetStream = "application/octet-stream"

	// 1. Content-Type header.
	mimeType := resp.Header.Get("Content-Type")
	if idx := strings.Index(mimeType, ";"); idx != -1 {
		mimeType = strings.TrimSpace(mimeType[:idx])
	}
	if mimeType != "" && mimeType != octetStream {
		return mimeType
	}

	// 2. Extension of the Content-Disposition filename.
	if cd := resp.Header.Get("Content-Disposition"); cd != "" {
		const filenameParam = "filename="
		for _, part := range strings.Split(cd, ";") {
			part = strings.TrimSpace(part)
			// The parameter name is case-insensitive, so it is matched with
			// EqualFold and removed by length rather than by a
			// case-sensitive TrimPrefix, which would leave `Filename="..."`
			// intact and spoil the extension.
			if len(part) < len(filenameParam) || !strings.EqualFold(part[:len(filenameParam)], filenameParam) {
				continue
			}
			name := strings.TrimSpace(part[len(filenameParam):])
			if len(name) >= 2 && name[0] == '"' && name[len(name)-1] == '"' {
				name = name[1 : len(name)-1]
			}
			if dot := strings.LastIndex(name, "."); dot != -1 && dot+1 < len(name) {
				ext := strings.ToLower(name[dot+1:])
				if mt := GetMimeTypeByExtension(ext); mt != octetStream {
					return mt
				}
			}
			// Only the first filename parameter is considered.
			break
		}
	}

	// 3. Extension of the URL path.
	if mt := guessMimeTypeFromURL(url); mt != octetStream {
		return mt
	}

	if len(fileBytes) > 0 {
		// 4. Content sniffing.
		sniffed := http.DetectContentType(fileBytes)
		if sniffed != "" && sniffed != octetStream {
			// Drop a trailing parameter such as "; charset=utf-8".
			if idx := strings.Index(sniffed, ";"); idx != -1 {
				sniffed = strings.TrimSpace(sniffed[:idx])
			}
			return sniffed
		}
		// 4.5 HEIF/HEIC signature.
		if heifMime := detectHEIF(fileBytes); heifMime != "" {
			return heifMime
		}
		// 5. Decode as an image and use the reported format.
		if _, format, err := decodeImageConfig(fileBytes); err == nil && format != "" {
			return "image/" + strings.ToLower(format)
		}
	}

	return octetStream
}
// loadFromBase64 turns a base64 string, optionally wrapped as a data: URL,
// into cached file data.
//
// For input of the form "data:<mime>;...,<payload>" the payload after the
// first comma is used and <mime> becomes the MIME type; a non-empty
// providedMimeType always wins over it. The payload must decode with
// base64.StdEncoding. Storage follows shouldUseDiskCache, silently falling
// back to memory when the disk write fails. When the MIME type is empty or
// image/*, the image configuration is decoded; on success it is attached and,
// if no MIME type was known, the type is set to "image/<format>".
func loadFromBase64(base64String string, providedMimeType string) (*types.CachedFileData, error) {
	const dataScheme = "data:"

	mimeType := ""
	cleanBase64 := base64String
	if strings.HasPrefix(base64String, dataScheme) {
		if comma := strings.Index(base64String, ","); comma != -1 {
			header := base64String[:comma]
			cleanBase64 = base64String[comma+1:]
			// header begins with "data:", so the MIME type is whatever sits
			// between the scheme and the first ';' (if that span is non-empty).
			if semi := strings.Index(header, ";"); semi > len(dataScheme) {
				mimeType = header[len(dataScheme):semi]
			}
		}
	}
	if providedMimeType != "" {
		mimeType = providedMimeType
	}

	decodedData, err := base64.StdEncoding.DecodeString(cleanBase64)
	if err != nil {
		return nil, fmt.Errorf("failed to decode base64 data: %w", err)
	}
	rawSize := int64(len(decodedData))
	base64Size := int64(len(cleanBase64))

	// Prefer the disk cache when allowed; anything that does not end up on
	// disk is kept in memory.
	var cachedData *types.CachedFileData
	if shouldUseDiskCache(base64Size) {
		if diskPath, diskErr := writeToDiskCache(cleanBase64); diskErr == nil {
			cachedData = types.NewDiskCachedData(diskPath, mimeType, rawSize)
			cachedData.DiskSize = base64Size
			cachedData.OnClose = func(size int64) {
				common.DecrementDiskFiles(size)
			}
			common.IncrementDiskFiles(base64Size)
		}
	}
	if cachedData == nil {
		cachedData = types.NewMemoryCachedData(cleanBase64, mimeType, rawSize)
	}

	// Only unknown or image payloads are probed for image metadata.
	if mimeType != "" && !strings.HasPrefix(mimeType, "image/") {
		return cachedData, nil
	}
	config, format, err := decodeImageConfig(decodedData)
	if err != nil {
		return cachedData, nil
	}
	cachedData.ImageConfig = &config
	cachedData.ImageFormat = format
	if mimeType == "" {
		cachedData.MimeType = "image/" + format
	}
	return cachedData, nil
}
// GetImageConfig returns the image configuration and format name for source.
//
// The source is loaded through LoadFileSource (reason "get_image_config").
// If the loaded data already carries an image configuration it is returned
// directly; otherwise the base64 payload is decoded, inspected with
// decodeImageConfig, and the result is stored back on the cached data before
// being returned.
func GetImageConfig(c *gin.Context, source types.FileSource) (image.Config, string, error) {
	var none image.Config

	cachedData, err := LoadFileSource(c, source, "get_image_config")
	if err != nil {
		return none, "", err
	}
	if known := cachedData.ImageConfig; known != nil {
		return *known, cachedData.ImageFormat, nil
	}

	encoded, err := cachedData.GetBase64Data()
	if err != nil {
		return none, "", fmt.Errorf("failed to get base64 data: %w", err)
	}
	raw, err := base64.StdEncoding.DecodeString(encoded)
	if err != nil {
		return none, "", fmt.Errorf("failed to decode base64 for image config: %w", err)
	}
	config, format, err := decodeImageConfig(raw)
	if err != nil {
		return none, "", err
	}

	// Remember the result so later calls take the early return above.
	cachedData.ImageConfig = &config
	cachedData.ImageFormat = format
	return config, format, nil
}
// GetBase64Data loads source through LoadFileSource and returns its base64
// payload together with its MIME type. reason is forwarded to LoadFileSource.
func GetBase64Data(c *gin.Context, source types.FileSource, reason ...string) (string, string, error) {
	cachedData, loadErr := LoadFileSource(c, source, reason...)
	if loadErr != nil {
		return "", "", loadErr
	}
	encoded, readErr := cachedData.GetBase64Data()
	if readErr != nil {
		return "", "", fmt.Errorf("failed to get base64 data: %w", readErr)
	}
	return encoded, cachedData.MimeType, nil
}
// GetMimeType returns the MIME type of source.
//
// Resolution order:
//  1. the MIME type of the data already cached on the source;
//  2. for URL sources, the answer of GetFileTypeFromUrl, unless it fails or
//     is empty / "application/octet-stream";
//  3. a full load through LoadFileSource.
//
// A nil source yields an error, matching LoadFileSource, instead of a nil
// dereference on the first method call.
func GetMimeType(c *gin.Context, source types.FileSource) (string, error) {
	if source == nil {
		return "", fmt.Errorf("file source is nil")
	}
	if source.HasCache() {
		return source.GetCache().MimeType, nil
	}
	if urlSource, ok := source.(*types.URLSource); ok {
		mimeType, err := GetFileTypeFromUrl(c, urlSource.URL, "get_mime_type")
		if err == nil && mimeType != "" && mimeType != "application/octet-stream" {
			return mimeType, nil
		}
		// Otherwise fall through and determine the type from the content.
	}
	cachedData, err := LoadFileSource(c, source, "get_mime_type")
	if err != nil {
		return "", err
	}
	return cachedData.MimeType, nil
}
// DetectFileType maps a MIME type to a coarse file category by its top-level
// type: "image/", "audio/" and "video/" prefixes map to the matching
// types.FileType value, everything else to types.FileTypeFile.
func DetectFileType(mimeType string) types.FileType {
	switch {
	case strings.HasPrefix(mimeType, "image/"):
		return types.FileTypeImage
	case strings.HasPrefix(mimeType, "audio/"):
		return types.FileTypeAudio
	case strings.HasPrefix(mimeType, "video/"):
		return types.FileTypeVideo
	default:
		return types.FileTypeFile
	}
}
// decodeImageConfig 从字节数据解码图片配置
func decodeImageConfig(data []byte) (image.Config, string, error) {
reader := bytes.NewReader(data)
config, format, err := image.DecodeConfig(reader)
if err == nil {
return config, format, nil
}
reader.Seek(0, io.SeekStart)
config, err = webp.DecodeConfig(reader)
if err == nil {
return config, "webp", nil
}
// Try HEIF/HEIC: parse ISOBMFF ispe box for dimensions
if heifMime := detectHEIF(data); heifMime != "" {
formatName := "heif"
if heifMime == "image/heic" {
formatName = "heic"
}
if w, h, ok := parseHEIFDimensions(data); ok {
return image.Config{Width: w, Height: h}, formatName, nil
}
return image.Config{}, "", fmt.Errorf("failed to decode HEIF/HEIC image dimensions")
}
return image.Config{}, "", fmt.Errorf("failed to decode image config: unsupported format")
}
// detectHEIF recognizes HEIC/HEIF files from the ISOBMFF "ftyp" header.
// Only the major brand (bytes 8..12) is examined. It returns "image/heic"
// for the brands heic, heix, hevc, hevx, heim and heis, "image/heif" for
// mif1 and msf1, and "" for anything else, including input too short to hold
// the header.
func detectHEIF(data []byte) string {
	// Need the 4-byte box size, the "ftyp" tag and the 4-byte major brand.
	if len(data) < 12 || string(data[4:8]) != "ftyp" {
		return ""
	}
	switch major := string(data[8:12]); major {
	case "mif1", "msf1":
		return "image/heif"
	case "heic", "heix", "hevc", "hevx", "heim", "heis":
		return "image/heic"
	}
	return ""
}
// parseHEIFDimensions walks the top-level ISOBMFF boxes of data until it
// finds the "meta" box and then delegates to findISPE to read the image
// width and height from it. It returns (width, height, true) on success and
// (0, 0, false) when no usable "meta" box is found or the box structure is
// malformed.
//
// Box sizes come from the file and are untrusted. They are kept unsigned and
// compared against the number of bytes actually remaining before being
// converted to int; adding an unchecked 64-bit size to the offset can wrap
// around, slip past the bounds test and lead to an out-of-range slice.
func parseHEIFDimensions(data []byte) (int, int, bool) {
	size := len(data)
	if size < 12 {
		return 0, 0, false
	}

	offset := 0
	for size-offset >= 8 {
		remaining := uint64(size - offset)
		boxLen := uint64(binary.BigEndian.Uint32(data[offset : offset+4]))
		boxType := string(data[offset+4 : offset+8])
		headerLen := 8

		switch boxLen {
		case 1:
			// A size of 1 means the real size follows as a 64-bit value.
			if remaining < 16 {
				return 0, 0, false
			}
			boxLen = binary.BigEndian.Uint64(data[offset+8 : offset+16])
			headerLen = 16
		case 0:
			// A size of 0 means the box runs to the end of the data.
			boxLen = remaining
		}
		if boxLen < uint64(headerLen) || boxLen > remaining {
			return 0, 0, false
		}
		// boxLen <= remaining <= len(data), so the conversion cannot overflow.
		boxSize := int(boxLen)

		if boxType == "meta" {
			// meta is a full box: 4 bytes of version/flags precede its children.
			metaData := data[offset+headerLen : offset+boxSize]
			if len(metaData) < 4 {
				return 0, 0, false
			}
			return findISPE(metaData[4:])
		}
		offset += boxSize
	}
	return 0, 0, false
}
// maxISPEDepth bounds how deeply findISPE follows nested container boxes.
// The expected path (meta -> iprp -> ipco -> ispe) needs two levels; the cap
// keeps a crafted file made of nested containers from driving the recursion
// arbitrarily deep.
const maxISPEDepth = 8

// findISPE searches the children of a "meta" box for "ispe" (image spatial
// extents) boxes, descending through "iprp" and "ipco" containers, and
// returns (width, height, true) for the ispe with the largest area, or
// (0, 0, false) if there is none with non-zero dimensions.
//
// The largest one is chosen rather than the first because a file may carry
// several ispe boxes (tiles of a grid image and thumbnails have their own),
// and their order in the file is not guaranteed to put the full image first.
// On equal area the earliest box wins.
func findISPE(data []byte) (int, int, bool) {
	w, h, _, ok := largestISPE(data, 0)
	return w, h, ok
}

// largestISPE scans the sibling boxes in data and returns the dimensions and
// area of the largest ispe found at this level or below. Scanning stops at
// the first malformed box; whatever was found before it is still returned.
func largestISPE(data []byte, depth int) (int, int, uint64, bool) {
	if depth > maxISPEDepth {
		return 0, 0, 0, false
	}

	var (
		bestW, bestH int
		bestArea     uint64
		found        bool
	)
	offset := 0
	for len(data)-offset >= 8 {
		boxLen := binary.BigEndian.Uint32(data[offset : offset+4])
		boxType := string(data[offset+4 : offset+8])
		// Compare unsigned against the bytes left, so the check cannot be
		// defeated by integer wrap-around on platforms with a 32-bit int.
		if boxLen < 8 || uint64(boxLen) > uint64(len(data)-offset) {
			break
		}
		boxSize := int(boxLen)
		content := data[offset+8 : offset+boxSize]

		var (
			w, h int
			area uint64
			ok   bool
		)
		switch boxType {
		case "iprp", "ipco":
			w, h, area, ok = largestISPE(content, depth+1)
		case "ispe":
			// Full box: 4 bytes version/flags, 4 bytes width, 4 bytes height.
			if len(content) >= 12 {
				rawW := binary.BigEndian.Uint32(content[4:8])
				rawH := binary.BigEndian.Uint32(content[8:12])
				w, h = int(rawW), int(rawH)
				area = uint64(rawW) * uint64(rawH)
				ok = w > 0 && h > 0
			}
		}
		if ok && (!found || area > bestArea) {
			bestW, bestH, bestArea, found = w, h, area, true
		}
		offset += boxSize
	}
	return bestW, bestH, bestArea, found
}
// guessMimeTypeFromURL guesses a MIME type from the file extension of the
// last path segment of url. The query string and the fragment are removed
// first, so "…/a.png?x=1" and "…/a.png#top" both yield the type for "png".
// The lower-cased extension is resolved by GetMimeTypeByExtension;
// "application/octet-stream" is returned when the URL has no slash, ends in a
// slash, or the last segment has no extension.
func guessMimeTypeFromURL(url string) string {
	const octetStream = "application/octet-stream"

	path := url
	// Whichever of '?' or '#' comes first ends the path component.
	if cut := strings.IndexAny(path, "?#"); cut != -1 {
		path = path[:cut]
	}

	slash := strings.LastIndex(path, "/")
	if slash == -1 || slash+1 >= len(path) {
		return octetStream
	}
	last := path[slash+1:]

	dot := strings.LastIndex(last, ".")
	if dot == -1 || dot+1 >= len(last) {
		return octetStream
	}
	return GetMimeTypeByExtension(strings.ToLower(last[dot+1:]))
}