Merge origin/main into nightly

Resolve conflicts:
- .gitignore: keep nightly additions (.test, skills-lock.json)
- relay/helper/price.go: keep both billingexpr and model imports
- en.json / zh-CN.json: keep nightly's superset of i18n entries
- service/billing_session.go: add missing 3rd arg to DecreaseUserQuota
- en.json / zh-CN.json: deduplicate 129+320 duplicate i18n keys
This commit is contained in:
CaIon
2026-04-23 21:23:40 +08:00
117 changed files with 9031 additions and 2910 deletions
+1 -1
View File
@@ -31,4 +31,4 @@ data/
.gopath .gopath
.test .test
token_estimator_test.go token_estimator_test.go
skills-lock.json skills-lock.json
+4
View File
@@ -116,6 +116,10 @@ var RetryTimes = 0
var IsMasterNode bool var IsMasterNode bool
// NodeName 节点名称,从 NODE_NAME 环境变量读取;
// 用于审计日志中标识节点身份,在容器/K8s 部署时比自动探测到的容器内网 IP 更具可读性。
var NodeName = ""
var requestInterval int var requestInterval int
var RequestInterval time.Duration var RequestInterval time.Duration
+1
View File
@@ -82,6 +82,7 @@ func InitEnv() {
DebugEnabled = os.Getenv("DEBUG") == "true" DebugEnabled = os.Getenv("DEBUG") == "true"
MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
IsMasterNode = os.Getenv("NODE_TYPE") != "slave" IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
NodeName = os.Getenv("NODE_NAME")
TLSInsecureSkipVerify = GetEnvOrDefaultBool("TLS_INSECURE_SKIP_VERIFY", false) TLSInsecureSkipVerify = GetEnvOrDefaultBool("TLS_INSECURE_SKIP_VERIFY", false)
if TLSInsecureSkipVerify { if TLSInsecureSkipVerify {
if tr, ok := http.DefaultTransport.(*http.Transport); ok && tr != nil { if tr, ok := http.DefaultTransport.(*http.Transport); ok && tr != nil {
+72 -28
View File
@@ -29,45 +29,89 @@ var DefaultSSRFProtection = &SSRFProtection{
AllowedPorts: []int{}, AllowedPorts: []int{},
} }
// isPrivateIP 检查IP是否为私有地址 // privateIPv4Nets IPv4 私有/保留/特殊用途网段
// 参考 IANA IPv4 Special-Purpose Address Registry
// https://www.iana.org/assignments/iana-ipv4-special-registry/
var privateIPv4Nets = []net.IPNet{
{IP: net.IPv4(0, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 0.0.0.0/8 ("This network" / 未指定)
{IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 10.0.0.0/8 (私有)
{IP: net.IPv4(100, 64, 0, 0), Mask: net.CIDRMask(10, 32)}, // 100.64.0.0/10 (运营商级 NAT / CGNAT)
{IP: net.IPv4(127, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 127.0.0.0/8 (回环)
{IP: net.IPv4(169, 254, 0, 0), Mask: net.CIDRMask(16, 32)}, // 169.254.0.0/16 (链路本地)
{IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)}, // 172.16.0.0/12 (私有)
{IP: net.IPv4(192, 0, 0, 0), Mask: net.CIDRMask(24, 32)}, // 192.0.0.0/24 (IETF 协议分配)
{IP: net.IPv4(192, 0, 2, 0), Mask: net.CIDRMask(24, 32)}, // 192.0.2.0/24 (TEST-NET-1)
{IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)}, // 192.168.0.0/16 (私有)
{IP: net.IPv4(198, 18, 0, 0), Mask: net.CIDRMask(15, 32)}, // 198.18.0.0/15 (基准测试)
{IP: net.IPv4(198, 51, 100, 0), Mask: net.CIDRMask(24, 32)}, // 198.51.100.0/24 (TEST-NET-2)
{IP: net.IPv4(203, 0, 113, 0), Mask: net.CIDRMask(24, 32)}, // 203.0.113.0/24 (TEST-NET-3)
{IP: net.IPv4(224, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 224.0.0.0/4 (组播)
{IP: net.IPv4(240, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 240.0.0.0/4 (保留)
{IP: net.IPv4(255, 255, 255, 255), Mask: net.CIDRMask(32, 32)}, // 255.255.255.255/32 (受限广播)
}
// privateIPv6Nets IPv6 私有/保留/特殊用途网段
// 参考 IANA IPv6 Special-Purpose Address Registry
// https://www.iana.org/assignments/iana-ipv6-special-registry/
var privateIPv6Nets = func() []net.IPNet {
cidrs := []string{
"::/128", // 未指定地址
"::1/128", // 回环
"::ffff:0:0/96", // IPv4-mapped
"64:ff9b::/96", // IPv4/IPv6 translation
"100::/64", // Discard-Only
"2001::/23", // IETF Protocol Assignments
"2001:db8::/32", // 文档
"fc00::/7", // Unique Local Address (ULA)
"fe80::/10", // 链路本地
"ff00::/8", // 组播
}
nets := make([]net.IPNet, 0, len(cidrs))
for _, c := range cidrs {
if _, n, err := net.ParseCIDR(c); err == nil && n != nil {
nets = append(nets, *n)
}
}
return nets
}()
// isPrivateIP 检查IP是否为私有/保留/特殊用途地址
func isPrivateIP(ip net.IP) bool { func isPrivateIP(ip net.IP) bool {
if ip == nil {
return true
}
// 未指定地址 (0.0.0.0, ::)
if ip.IsUnspecified() {
return true
}
// 回环、链路本地 (unicast/multicast)
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
return true return true
} }
// 接口本地组播 (IPv6 ff01::/16 等)
// 检查私有网段 if ip.IsInterfaceLocalMulticast() {
private := []net.IPNet{ return true
{IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 10.0.0.0/8
{IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)}, // 172.16.0.0/12
{IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)}, // 192.168.0.0/16
{IP: net.IPv4(127, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 127.0.0.0/8
{IP: net.IPv4(169, 254, 0, 0), Mask: net.CIDRMask(16, 32)}, // 169.254.0.0/16 (链路本地)
{IP: net.IPv4(224, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 224.0.0.0/4 (组播)
{IP: net.IPv4(240, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 240.0.0.0/4 (保留)
} }
for _, privateNet := range private { if v4 := ip.To4(); v4 != nil {
for _, privateNet := range privateIPv4Nets {
if privateNet.Contains(v4) {
return true
}
}
return false
}
// IPv6 检查
for _, privateNet := range privateIPv6Nets {
if privateNet.Contains(ip) { if privateNet.Contains(ip) {
return true return true
} }
} }
// 兜底: Go 标准库识别的其他私有地址
// 检查IPv6私有地址 if ip.IsPrivate() {
if ip.To4() == nil { return true
// IPv6 loopback
if ip.Equal(net.IPv6loopback) {
return true
}
// IPv6 link-local
if strings.HasPrefix(ip.String(), "fe80:") {
return true
}
// IPv6 unique local
if strings.HasPrefix(ip.String(), "fc") || strings.HasPrefix(ip.String(), "fd") {
return true
}
} }
return false return false
} }
+1
View File
@@ -65,4 +65,5 @@ const (
// ContextKeyLanguage stores the user's language preference for i18n // ContextKeyLanguage stores the user's language preference for i18n
ContextKeyLanguage ContextKey = "language" ContextKeyLanguage ContextKey = "language"
ContextKeyIsStream ContextKey = "is_stream"
) )
+51 -9
View File
@@ -151,6 +151,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
} }
} }
cache.WriteContext(c) cache.WriteContext(c)
c.Set("id", 1)
//c.Request.Header.Set("Authorization", "Bearer "+channel.Key) //c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json") c.Request.Header.Set("Content-Type", "application/json")
@@ -284,7 +285,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
return testResult{ return testResult{
context: c, context: c,
localErr: err, localErr: err,
newAPIError: types.NewError(err, types.ErrorCodeModelPriceError), newAPIError: types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithStatusCode(http.StatusBadRequest)),
} }
} }
@@ -469,7 +470,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
} }
} }
if bodyErr := detectErrorFromTestResponseBody(respBody); bodyErr != nil { if bodyErr := validateTestResponseBody(respBody, isStream); bodyErr != nil {
return testResult{ return testResult{
context: c, context: c,
localErr: bodyErr, localErr: bodyErr,
@@ -613,6 +614,42 @@ func detectErrorFromTestResponseBody(respBody []byte) error {
return nil return nil
} }
func validateStreamTestResponseBody(respBody []byte) error {
b := bytes.TrimSpace(respBody)
if len(b) == 0 {
return errors.New("stream response body is empty")
}
for _, line := range bytes.Split(b, []byte{'\n'}) {
line = bytes.TrimSpace(line)
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data:")) {
continue
}
payload := bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data:")))
if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) {
continue
}
return nil
}
return errors.New("stream response body does not contain a valid stream event")
}
func validateTestResponseBody(respBody []byte, isStream bool) error {
if bodyErr := detectErrorFromTestResponseBody(respBody); bodyErr != nil {
return bodyErr
}
if isStream {
return validateStreamTestResponseBody(respBody)
}
return nil
}
func shouldUseStreamForAutomaticChannelTest(channel *model.Channel) bool {
return channel != nil && channel.Type == constant.ChannelTypeCodex
}
func detectErrorMessageFromJSONBytes(jsonBytes []byte) string { func detectErrorMessageFromJSONBytes(jsonBytes []byte) string {
if len(jsonBytes) == 0 { if len(jsonBytes) == 0 {
return "" return ""
@@ -800,11 +837,15 @@ func TestChannel(c *gin.Context) {
tik := time.Now() tik := time.Now()
result := testChannel(channel, testModel, endpointType, isStream) result := testChannel(channel, testModel, endpointType, isStream)
if result.localErr != nil { if result.localErr != nil {
c.JSON(http.StatusOK, gin.H{ resp := gin.H{
"success": false, "success": false,
"message": result.localErr.Error(), "message": result.localErr.Error(),
"time": 0.0, "time": 0.0,
}) }
if result.newAPIError != nil {
resp["error_code"] = result.newAPIError.GetErrorCode()
}
c.JSON(http.StatusOK, resp)
return return
} }
tok := time.Now() tok := time.Now()
@@ -813,9 +854,10 @@ func TestChannel(c *gin.Context) {
consumedTime := float64(milliseconds) / 1000.0 consumedTime := float64(milliseconds) / 1000.0
if result.newAPIError != nil { if result.newAPIError != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": result.newAPIError.Error(), "message": result.newAPIError.Error(),
"time": consumedTime, "time": consumedTime,
"error_code": result.newAPIError.GetErrorCode(),
}) })
return return
} }
@@ -860,7 +902,7 @@ func testAllChannels(notify bool) error {
} }
isChannelEnabled := channel.Status == common.ChannelStatusEnabled isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now() tik := time.Now()
result := testChannel(channel, "", "", false) result := testChannel(channel, "", "", shouldUseStreamForAutomaticChannelTest(channel))
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
@@ -868,7 +910,7 @@ func testAllChannels(notify bool) error {
newAPIError := result.newAPIError newAPIError := result.newAPIError
// request error disables the channel // request error disables the channel
if newAPIError != nil { if newAPIError != nil {
shouldBanChannel = service.ShouldDisableChannel(channel.Type, result.newAPIError) shouldBanChannel = service.ShouldDisableChannel(result.newAPIError)
} }
// 当错误检查通过,才检查响应时间 // 当错误检查通过,才检查响应时间
+12 -2
View File
@@ -27,6 +27,15 @@ var completionRatioMetaOptionKeys = []string{
"AudioCompletionRatio", "AudioCompletionRatio",
} }
func isVisiblePublicKeyOption(key string) bool {
switch key {
case "WaffoPancakeWebhookPublicKey", "WaffoPancakeWebhookTestKey":
return true
default:
return false
}
}
func collectModelNamesFromOptionValue(raw string, modelNames map[string]struct{}) { func collectModelNamesFromOptionValue(raw string, modelNames map[string]struct{}) {
if strings.TrimSpace(raw) == "" { if strings.TrimSpace(raw) == "" {
return return
@@ -66,11 +75,12 @@ func GetOptions(c *gin.Context) {
common.OptionMapRWMutex.Lock() common.OptionMapRWMutex.Lock()
for k, v := range common.OptionMap { for k, v := range common.OptionMap {
value := common.Interface2String(v) value := common.Interface2String(v)
if strings.HasSuffix(k, "Token") || isSensitiveKey := strings.HasSuffix(k, "Token") ||
strings.HasSuffix(k, "Secret") || strings.HasSuffix(k, "Secret") ||
strings.HasSuffix(k, "Key") || strings.HasSuffix(k, "Key") ||
strings.HasSuffix(k, "secret") || strings.HasSuffix(k, "secret") ||
strings.HasSuffix(k, "api_key") { strings.HasSuffix(k, "api_key")
if isSensitiveKey && !isVisiblePublicKeyOption(k) {
continue continue
} }
options = append(options, &model.Option{ options = append(options, &model.Option{
+70
View File
@@ -36,6 +36,10 @@ func PasskeyRegisterBegin(c *gin.Context) {
return return
} }
if !requirePasskeyRegistrationVerification(c, user.Id) {
return
}
credential, err := model.GetPasskeyByUserID(user.Id) credential, err := model.GetPasskeyByUserID(user.Id)
if err != nil && !errors.Is(err, model.ErrPasskeyNotFound) { if err != nil && !errors.Is(err, model.ErrPasskeyNotFound) {
common.ApiError(c, err) common.ApiError(c, err)
@@ -96,6 +100,10 @@ func PasskeyRegisterFinish(c *gin.Context) {
return return
} }
if !requirePasskeyRegistrationVerification(c, user.Id) {
return
}
wa, err := passkeysvc.BuildWebAuthn(c.Request) wa, err := passkeysvc.BuildWebAuthn(c.Request)
if err != nil { if err != nil {
common.ApiError(c, err) common.ApiError(c, err)
@@ -151,6 +159,10 @@ func PasskeyDelete(c *gin.Context) {
return return
} }
if !requirePasskeyDeleteVerification(c, user.Id) {
return
}
if err := model.DeletePasskeyByUserID(user.Id); err != nil { if err := model.DeletePasskeyByUserID(user.Id); err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
@@ -474,6 +486,7 @@ func PasskeyVerifyFinish(c *gin.Context) {
// Mark passkey as ready; /api/verify will convert this into the final secure verification session. // Mark passkey as ready; /api/verify will convert this into the final secure verification session.
session.Set(PasskeyReadySessionKey, time.Now().Unix()) session.Set(PasskeyReadySessionKey, time.Now().Unix())
session.Delete(SecureVerificationSessionKey) session.Delete(SecureVerificationSessionKey)
session.Delete(secureVerificationMethodSessionKey)
if err := session.Save(); err != nil { if err := session.Save(); err != nil {
common.ApiError(c, fmt.Errorf("保存验证状态失败: %v", err)) common.ApiError(c, fmt.Errorf("保存验证状态失败: %v", err))
return return
@@ -504,3 +517,60 @@ func getSessionUser(c *gin.Context) (*model.User, error) {
} }
return user, nil return user, nil
} }
func requirePasskeyRegistrationVerification(c *gin.Context, userID int) bool {
twoFA, err := model.GetTwoFAByUserId(userID)
if err != nil {
common.ApiError(c, err)
return false
}
if twoFA == nil || !twoFA.IsEnabled {
return true
}
return requireSecureVerificationMethod(c, secureVerificationMethod2FA)
}
func requirePasskeyDeleteVerification(c *gin.Context, userID int) bool {
twoFA, err := model.GetTwoFAByUserId(userID)
if err != nil {
common.ApiError(c, err)
return false
}
if twoFA != nil && twoFA.IsEnabled {
return requireSecureVerificationMethod(c, secureVerificationMethod2FA)
}
_, err = model.GetPasskeyByUserID(userID)
if err != nil {
if errors.Is(err, model.ErrPasskeyNotFound) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该用户尚未绑定 Passkey",
})
return false
}
common.ApiError(c, err)
return false
}
return requireSecureVerificationMethod(c, secureVerificationMethodPasskey)
}
func requireSecureVerificationMethod(c *gin.Context, method string) bool {
session := sessions.Default(c)
verifiedAt, ok := session.Get(SecureVerificationSessionKey).(int64)
if !ok || time.Now().Unix()-verifiedAt >= SecureVerificationTimeout {
session.Delete(SecureVerificationSessionKey)
session.Delete(secureVerificationMethodSessionKey)
_ = session.Save()
common.ApiErrorMsg(c, "请先完成安全验证")
return false
}
if verifiedMethod, ok := session.Get(secureVerificationMethodSessionKey).(string); !ok || verifiedMethod != method {
common.ApiErrorMsg(c, "请先完成对应的安全验证")
return false
}
return true
}
+100
View File
@@ -0,0 +1,100 @@
package controller
import (
"strings"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/operation_setting"
)
func isStripeTopUpEnabled() bool {
return strings.TrimSpace(setting.StripeApiSecret) != "" &&
strings.TrimSpace(setting.StripeWebhookSecret) != "" &&
strings.TrimSpace(setting.StripePriceId) != ""
}
func isStripeWebhookConfigured() bool {
return strings.TrimSpace(setting.StripeWebhookSecret) != ""
}
func isStripeWebhookEnabled() bool {
return isStripeTopUpEnabled()
}
func isCreemTopUpEnabled() bool {
products := strings.TrimSpace(setting.CreemProducts)
return strings.TrimSpace(setting.CreemApiKey) != "" &&
products != "" &&
products != "[]"
}
func isCreemWebhookConfigured() bool {
return strings.TrimSpace(setting.CreemWebhookSecret) != ""
}
func isCreemWebhookEnabled() bool {
return isCreemTopUpEnabled() && isCreemWebhookConfigured()
}
func isWaffoTopUpEnabled() bool {
if !setting.WaffoEnabled {
return false
}
return isWaffoWebhookConfigured()
}
func isWaffoWebhookConfigured() bool {
if setting.WaffoSandbox {
return strings.TrimSpace(setting.WaffoSandboxApiKey) != "" &&
strings.TrimSpace(setting.WaffoSandboxPrivateKey) != "" &&
strings.TrimSpace(setting.WaffoSandboxPublicCert) != ""
}
return strings.TrimSpace(setting.WaffoApiKey) != "" &&
strings.TrimSpace(setting.WaffoPrivateKey) != "" &&
strings.TrimSpace(setting.WaffoPublicCert) != ""
}
func isWaffoWebhookEnabled() bool {
return isWaffoTopUpEnabled()
}
func isWaffoPancakeTopUpEnabled() bool {
if !setting.WaffoPancakeEnabled {
return false
}
return isWaffoPancakeWebhookConfigured() &&
strings.TrimSpace(setting.WaffoPancakeMerchantID) != "" &&
strings.TrimSpace(setting.WaffoPancakePrivateKey) != "" &&
strings.TrimSpace(setting.WaffoPancakeStoreID) != "" &&
strings.TrimSpace(setting.WaffoPancakeProductID) != ""
}
func isWaffoPancakeWebhookConfigured() bool {
currentWebhookKey := strings.TrimSpace(setting.WaffoPancakeWebhookPublicKey)
if setting.WaffoPancakeSandbox {
currentWebhookKey = strings.TrimSpace(setting.WaffoPancakeWebhookTestKey)
}
return currentWebhookKey != ""
}
func isWaffoPancakeWebhookEnabled() bool {
return isWaffoPancakeTopUpEnabled()
}
func isEpayTopUpEnabled() bool {
return isEpayWebhookConfigured() && len(operation_setting.PayMethods) > 0
}
func isEpayWebhookConfigured() bool {
return strings.TrimSpace(operation_setting.PayAddress) != "" &&
strings.TrimSpace(operation_setting.EpayId) != "" &&
strings.TrimSpace(operation_setting.EpayKey) != ""
}
func isEpayWebhookEnabled() bool {
return isEpayTopUpEnabled()
}
@@ -0,0 +1,166 @@
package controller
import (
"testing"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/stretchr/testify/require"
)
func TestStripeWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
originalAPISecret := setting.StripeApiSecret
originalWebhookSecret := setting.StripeWebhookSecret
originalPriceID := setting.StripePriceId
t.Cleanup(func() {
setting.StripeApiSecret = originalAPISecret
setting.StripeWebhookSecret = originalWebhookSecret
setting.StripePriceId = originalPriceID
})
setting.StripeWebhookSecret = ""
setting.StripeApiSecret = "sk_test_123"
setting.StripePriceId = "price_123"
require.False(t, isStripeWebhookEnabled())
setting.StripeWebhookSecret = "whsec_test"
require.True(t, isStripeWebhookEnabled())
setting.StripePriceId = ""
require.False(t, isStripeWebhookEnabled())
}
func TestCreemWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
originalAPIKey := setting.CreemApiKey
originalProducts := setting.CreemProducts
originalWebhookSecret := setting.CreemWebhookSecret
t.Cleanup(func() {
setting.CreemApiKey = originalAPIKey
setting.CreemProducts = originalProducts
setting.CreemWebhookSecret = originalWebhookSecret
})
setting.CreemWebhookSecret = ""
setting.CreemApiKey = "creem_api_key"
setting.CreemProducts = `[{"productId":"prod_123"}]`
require.False(t, isCreemWebhookEnabled())
setting.CreemWebhookSecret = "creem_secret"
require.True(t, isCreemWebhookEnabled())
setting.CreemProducts = "[]"
require.False(t, isCreemWebhookEnabled())
}
func TestWaffoWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
originalEnabled := setting.WaffoEnabled
originalSandbox := setting.WaffoSandbox
originalAPIKey := setting.WaffoApiKey
originalPrivateKey := setting.WaffoPrivateKey
originalPublicCert := setting.WaffoPublicCert
originalSandboxAPIKey := setting.WaffoSandboxApiKey
originalSandboxPrivateKey := setting.WaffoSandboxPrivateKey
originalSandboxPublicCert := setting.WaffoSandboxPublicCert
t.Cleanup(func() {
setting.WaffoEnabled = originalEnabled
setting.WaffoSandbox = originalSandbox
setting.WaffoApiKey = originalAPIKey
setting.WaffoPrivateKey = originalPrivateKey
setting.WaffoPublicCert = originalPublicCert
setting.WaffoSandboxApiKey = originalSandboxAPIKey
setting.WaffoSandboxPrivateKey = originalSandboxPrivateKey
setting.WaffoSandboxPublicCert = originalSandboxPublicCert
})
setting.WaffoEnabled = true
setting.WaffoSandbox = false
setting.WaffoApiKey = ""
setting.WaffoPrivateKey = "private"
setting.WaffoPublicCert = "public"
require.False(t, isWaffoWebhookEnabled())
setting.WaffoApiKey = "api"
require.True(t, isWaffoWebhookEnabled())
setting.WaffoEnabled = false
require.False(t, isWaffoWebhookEnabled())
setting.WaffoEnabled = true
setting.WaffoSandbox = true
setting.WaffoSandboxApiKey = ""
setting.WaffoSandboxPrivateKey = "sandbox_private"
setting.WaffoSandboxPublicCert = "sandbox_public"
require.False(t, isWaffoWebhookEnabled())
setting.WaffoSandboxApiKey = "sandbox_api"
require.True(t, isWaffoWebhookEnabled())
}
func TestWaffoPancakeWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
originalEnabled := setting.WaffoPancakeEnabled
originalSandbox := setting.WaffoPancakeSandbox
originalMerchantID := setting.WaffoPancakeMerchantID
originalPrivateKey := setting.WaffoPancakePrivateKey
originalWebhookPublicKey := setting.WaffoPancakeWebhookPublicKey
originalWebhookTestKey := setting.WaffoPancakeWebhookTestKey
originalStoreID := setting.WaffoPancakeStoreID
originalProductID := setting.WaffoPancakeProductID
t.Cleanup(func() {
setting.WaffoPancakeEnabled = originalEnabled
setting.WaffoPancakeSandbox = originalSandbox
setting.WaffoPancakeMerchantID = originalMerchantID
setting.WaffoPancakePrivateKey = originalPrivateKey
setting.WaffoPancakeWebhookPublicKey = originalWebhookPublicKey
setting.WaffoPancakeWebhookTestKey = originalWebhookTestKey
setting.WaffoPancakeStoreID = originalStoreID
setting.WaffoPancakeProductID = originalProductID
})
setting.WaffoPancakeEnabled = true
setting.WaffoPancakeSandbox = false
setting.WaffoPancakeMerchantID = "merchant"
setting.WaffoPancakePrivateKey = "private"
setting.WaffoPancakeStoreID = "store"
setting.WaffoPancakeProductID = "product"
setting.WaffoPancakeWebhookPublicKey = ""
require.False(t, isWaffoPancakeWebhookEnabled())
setting.WaffoPancakeWebhookPublicKey = "public"
require.True(t, isWaffoPancakeWebhookEnabled())
setting.WaffoPancakeEnabled = false
require.False(t, isWaffoPancakeWebhookEnabled())
setting.WaffoPancakeEnabled = true
setting.WaffoPancakeSandbox = true
setting.WaffoPancakeWebhookTestKey = ""
require.False(t, isWaffoPancakeWebhookEnabled())
setting.WaffoPancakeWebhookTestKey = "test_public"
require.True(t, isWaffoPancakeWebhookEnabled())
}
func TestEpayWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
originalPayAddress := operation_setting.PayAddress
originalEpayID := operation_setting.EpayId
originalEpayKey := operation_setting.EpayKey
originalPayMethods := operation_setting.PayMethods
t.Cleanup(func() {
operation_setting.PayAddress = originalPayAddress
operation_setting.EpayId = originalEpayID
operation_setting.EpayKey = originalEpayKey
operation_setting.PayMethods = originalPayMethods
})
operation_setting.PayAddress = "https://pay.example.com"
operation_setting.EpayId = "epay_id"
operation_setting.EpayKey = ""
operation_setting.PayMethods = []map[string]string{{"type": "alipay"}}
require.False(t, isEpayWebhookEnabled())
operation_setting.EpayKey = "epay_key"
require.True(t, isEpayWebhookEnabled())
operation_setting.PayMethods = nil
require.False(t, isEpayWebhookEnabled())
}
+3 -3
View File
@@ -151,7 +151,7 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
priceData, err := helper.ModelPriceHelper(c, relayInfo, tokens, meta) priceData, err := helper.ModelPriceHelper(c, relayInfo, tokens, meta)
if err != nil { if err != nil {
newAPIError = types.NewError(err, types.ErrorCodeModelPriceError) newAPIError = types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithStatusCode(http.StatusBadRequest))
return return
} }
@@ -351,7 +351,7 @@ func processChannelError(c *gin.Context, channelError types.ChannelError, err *t
logger.LogError(c, fmt.Sprintf("channel error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error())) logger.LogError(c, fmt.Sprintf("channel error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况 // 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
if service.ShouldDisableChannel(channelError.ChannelType, err) && channelError.AutoBan { if service.ShouldDisableChannel(err) && channelError.AutoBan {
gopool.Go(func() { gopool.Go(func() {
service.DisableChannel(channelError, err.ErrorWithStatusCode()) service.DisableChannel(channelError, err.ErrorWithStatusCode())
}) })
@@ -389,7 +389,7 @@ func processChannelError(c *gin.Context, channelError types.ChannelError, err *t
startTime = time.Now() startTime = time.Now()
} }
useTimeSeconds := int(time.Since(startTime).Seconds()) useTimeSeconds := int(time.Since(startTime).Seconds())
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveErrorWithStatusCode(), tokenId, useTimeSeconds, false, userGroup, other) model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveErrorWithStatusCode(), tokenId, useTimeSeconds, common.GetContextKeyBool(c, constant.ContextKeyIsStream), userGroup, other)
} }
} }
+7 -3
View File
@@ -13,7 +13,10 @@ import (
const ( const (
// SecureVerificationSessionKey means the user has fully passed secure verification. // SecureVerificationSessionKey means the user has fully passed secure verification.
SecureVerificationSessionKey = "secure_verified_at" SecureVerificationSessionKey = "secure_verified_at"
secureVerificationMethodSessionKey = "secure_verified_method"
secureVerificationMethod2FA = "2fa"
secureVerificationMethodPasskey = "passkey"
// PasskeyReadySessionKey means WebAuthn finished and /api/verify can finalize step-up verification. // PasskeyReadySessionKey means WebAuthn finished and /api/verify can finalize step-up verification.
PasskeyReadySessionKey = "secure_passkey_ready_at" PasskeyReadySessionKey = "secure_passkey_ready_at"
// SecureVerificationTimeout 验证有效期(秒) // SecureVerificationTimeout 验证有效期(秒)
@@ -120,7 +123,7 @@ func UniversalVerify(c *gin.Context) {
} }
// 验证成功,在 session 中记录时间戳 // 验证成功,在 session 中记录时间戳
now, err := setSecureVerificationSession(c) now, err := setSecureVerificationSession(c, req.Method)
if err != nil { if err != nil {
common.ApiError(c, fmt.Errorf("保存验证状态失败: %v", err)) common.ApiError(c, fmt.Errorf("保存验证状态失败: %v", err))
return return
@@ -139,11 +142,12 @@ func UniversalVerify(c *gin.Context) {
}) })
} }
func setSecureVerificationSession(c *gin.Context) (int64, error) { func setSecureVerificationSession(c *gin.Context, method string) (int64, error) {
session := sessions.Default(c) session := sessions.Default(c)
session.Delete(PasskeyReadySessionKey) session.Delete(PasskeyReadySessionKey)
now := time.Now().Unix() now := time.Now().Unix()
session.Set(SecureVerificationSessionKey, now) session.Set(SecureVerificationSessionKey, now)
session.Set(secureVerificationMethodSessionKey, method)
if err := session.Save(); err != nil { if err := session.Save(); err != nil {
return 0, err return 0, err
} }
+12 -10
View File
@@ -2,11 +2,13 @@ package controller
import ( import (
"bytes" "bytes"
"fmt"
"io" "io"
"log" "net/http"
"time" "time"
"github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting" "github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/operation_setting" "github.com/QuantumNous/new-api/setting/operation_setting"
@@ -24,14 +26,14 @@ func SubscriptionRequestCreemPay(c *gin.Context) {
// Keep body for debugging consistency (like RequestCreemPay) // Keep body for debugging consistency (like RequestCreemPay)
bodyBytes, err := io.ReadAll(c.Request.Body) bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil { if err != nil {
log.Printf("read subscription creem pay req body err: %v", err) logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 订阅支付请求读取失败 error=%q", err.Error()))
c.JSON(200, gin.H{"message": "error", "data": "read query error"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "read query error"})
return return
} }
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
if err := c.ShouldBindJSON(&req); err != nil || req.PlanId <= 0 { if err := c.ShouldBindJSON(&req); err != nil || req.PlanId <= 0 {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return return
} }
@@ -85,12 +87,12 @@ func SubscriptionRequestCreemPay(c *gin.Context) {
PlanId: plan.Id, PlanId: plan.Id,
Money: plan.PriceAmount, Money: plan.PriceAmount,
TradeNo: referenceId, TradeNo: referenceId,
PaymentMethod: PaymentMethodCreem, PaymentMethod: model.PaymentMethodCreem,
CreateTime: time.Now().Unix(), CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending, Status: common.TopUpStatusPending,
} }
if err := order.Insert(); err != nil { if err := order.Insert(); err != nil {
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
return return
} }
@@ -112,14 +114,14 @@ func SubscriptionRequestCreemPay(c *gin.Context) {
Quota: 0, Quota: 0,
} }
checkoutUrl, err := genCreemLink(referenceId, product, user.Email, user.Username) checkoutUrl, err := genCreemLink(c.Request.Context(), referenceId, product, user.Email, user.Username)
if err != nil { if err != nil {
log.Printf("获取Creem支付链接失败: %v", err) logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 订阅支付链接创建失败 trade_no=%s product_id=%s error=%q", referenceId, product.ProductId, err.Error()))
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
return return
} }
c.JSON(200, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "success", "message": "success",
"data": gin.H{ "data": gin.H{
"checkout_url": checkoutUrl, "checkout_url": checkoutUrl,
+3 -3
View File
@@ -104,7 +104,7 @@ func SubscriptionRequestEpay(c *gin.Context) {
ReturnUrl: returnUrl, ReturnUrl: returnUrl,
}) })
if err != nil { if err != nil {
_ = model.ExpireSubscriptionOrder(tradeNo) _ = model.ExpireSubscriptionOrder(tradeNo, req.PaymentMethod)
common.ApiErrorMsg(c, "拉起支付失败") common.ApiErrorMsg(c, "拉起支付失败")
return return
} }
@@ -156,7 +156,7 @@ func SubscriptionEpayNotify(c *gin.Context) {
LockOrder(verifyInfo.ServiceTradeNo) LockOrder(verifyInfo.ServiceTradeNo)
defer UnlockOrder(verifyInfo.ServiceTradeNo) defer UnlockOrder(verifyInfo.ServiceTradeNo)
if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo)); err != nil { if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), verifyInfo.Type); err != nil {
_, _ = c.Writer.Write([]byte("fail")) _, _ = c.Writer.Write([]byte("fail"))
return return
} }
@@ -205,7 +205,7 @@ func SubscriptionEpayReturn(c *gin.Context) {
if verifyInfo.TradeStatus == epay.StatusTradeSuccess { if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
LockOrder(verifyInfo.ServiceTradeNo) LockOrder(verifyInfo.ServiceTradeNo)
defer UnlockOrder(verifyInfo.ServiceTradeNo) defer UnlockOrder(verifyInfo.ServiceTradeNo)
if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo)); err != nil { if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), verifyInfo.Type); err != nil {
c.Redirect(http.StatusFound, system_setting.ServerAddress+"/console/topup?pay=fail") c.Redirect(http.StatusFound, system_setting.ServerAddress+"/console/topup?pay=fail")
return return
} }
+3 -3
View File
@@ -2,12 +2,12 @@ package controller
import ( import (
"fmt" "fmt"
"log"
"net/http" "net/http"
"strings" "strings"
"time" "time"
"github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting" "github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/system_setting" "github.com/QuantumNous/new-api/setting/system_setting"
@@ -78,7 +78,7 @@ func SubscriptionRequestStripePay(c *gin.Context) {
payLink, err := genStripeSubscriptionLink(referenceId, user.StripeCustomer, user.Email, plan.StripePriceId) payLink, err := genStripeSubscriptionLink(referenceId, user.StripeCustomer, user.Email, plan.StripePriceId)
if err != nil { if err != nil {
log.Println("获取Stripe Checkout支付链接失败", err) logger.LogError(c.Request.Context(), fmt.Sprintf("Stripe 订阅支付链接创建失败 trade_no=%s plan_id=%d error=%q", referenceId, plan.Id, err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
return return
} }
@@ -88,7 +88,7 @@ func SubscriptionRequestStripePay(c *gin.Context) {
PlanId: plan.Id, PlanId: plan.Id,
Money: plan.PriceAmount, Money: plan.PriceAmount,
TradeNo: referenceId, TradeNo: referenceId,
PaymentMethod: PaymentMethodStripe, PaymentMethod: model.PaymentMethodStripe,
CreateTime: time.Now().Unix(), CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending, Status: common.TopUpStatusPending,
} }
+271 -5
View File
@@ -2,10 +2,12 @@ package controller
import ( import (
"bytes" "bytes"
"database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os"
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
@@ -14,6 +16,8 @@ import (
"github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/model"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/glebarez/sqlite" "github.com/glebarez/sqlite"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -38,7 +42,36 @@ type tokenKeyResponse struct {
Key string `json:"key"` Key string `json:"key"`
} }
func setupTokenControllerTestDB(t *testing.T) *gorm.DB { type sqliteColumnInfo struct {
Name string `gorm:"column:name"`
Type string `gorm:"column:type"`
}
type legacyToken struct {
Id int `gorm:"primaryKey"`
UserId int `gorm:"index"`
Key string `gorm:"column:key;type:char(48);uniqueIndex"`
Status int `gorm:"default:1"`
Name string `gorm:"index"`
CreatedTime int64 `gorm:"bigint"`
AccessedTime int64 `gorm:"bigint"`
ExpiredTime int64 `gorm:"bigint;default:-1"`
RemainQuota int `gorm:"default:0"`
UnlimitedQuota bool
ModelLimitsEnabled bool
ModelLimits string `gorm:"type:text"`
AllowIps *string `gorm:"default:''"`
UsedQuota int `gorm:"default:0"`
Group string `gorm:"column:group;default:''"`
CrossGroupRetry bool
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func (legacyToken) TableName() string {
return "tokens"
}
func openTokenControllerTestDB(t *testing.T) *gorm.DB {
t.Helper() t.Helper()
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
@@ -55,10 +88,6 @@ func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
model.DB = db model.DB = db
model.LOG_DB = db model.LOG_DB = db
if err := db.AutoMigrate(&model.Token{}); err != nil {
t.Fatalf("failed to migrate token table: %v", err)
}
t.Cleanup(func() { t.Cleanup(func() {
sqlDB, err := db.DB() sqlDB, err := db.DB()
if err == nil { if err == nil {
@@ -69,6 +98,69 @@ func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
return db return db
} }
func migrateTokenControllerTestDB(t *testing.T, db *gorm.DB) {
t.Helper()
if err := db.AutoMigrate(&model.Token{}); err != nil {
t.Fatalf("failed to migrate token table: %v", err)
}
}
func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
t.Helper()
db := openTokenControllerTestDB(t)
migrateTokenControllerTestDB(t, db)
return db
}
func openTokenControllerExternalDB(t *testing.T, dialect string, dsn string) (*gorm.DB, *bool) {
t.Helper()
gin.SetMode(gin.TestMode)
common.RedisEnabled = false
common.UsingSQLite = false
common.UsingMySQL = dialect == "mysql"
common.UsingPostgreSQL = dialect == "postgres"
var (
db *gorm.DB
err error
)
switch dialect {
case "mysql":
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
case "postgres":
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
default:
t.Fatalf("unsupported dialect %q", dialect)
}
if err != nil {
t.Fatalf("failed to open %s db: %v", dialect, err)
}
model.DB = db
model.LOG_DB = db
if db.Migrator().HasTable("tokens") {
t.Skipf("refusing to run %s migration compatibility test against external database because tokens table already exists", dialect)
}
managedTokensTable := new(bool)
t.Cleanup(func() {
if *managedTokensTable && db.Migrator().HasTable("tokens") {
_ = db.Migrator().DropTable("tokens")
}
sqlDB, err := db.DB()
if err == nil {
_ = sqlDB.Close()
}
})
return db, managedTokensTable
}
func seedToken(t *testing.T, db *gorm.DB, userID int, name string, rawKey string) *model.Token { func seedToken(t *testing.T, db *gorm.DB, userID int, name string, rawKey string) *model.Token {
t.Helper() t.Helper()
@@ -124,6 +216,180 @@ func decodeAPIResponse(t *testing.T, recorder *httptest.ResponseRecorder) tokenA
return response return response
} }
func getSQLiteColumnType(t *testing.T, db *gorm.DB, tableName string, columnName string) string {
t.Helper()
var columns []sqliteColumnInfo
if err := db.Raw("PRAGMA table_info(" + tableName + ")").Scan(&columns).Error; err != nil {
t.Fatalf("failed to inspect %s schema: %v", tableName, err)
}
for _, column := range columns {
if column.Name == columnName {
return strings.ToLower(column.Type)
}
}
t.Fatalf("column %s not found in %s schema", columnName, tableName)
return ""
}
func getTokenKeyColumnType(t *testing.T, db *gorm.DB, dialect string) string {
t.Helper()
switch dialect {
case "sqlite":
return getSQLiteColumnType(t, db, "tokens", "key")
case "mysql":
var columnType string
if err := db.Raw(`SELECT COLUMN_TYPE FROM information_schema.columns
WHERE table_schema = DATABASE() AND table_name = ? AND column_name = ?`,
"tokens", "key").Scan(&columnType).Error; err != nil {
t.Fatalf("failed to inspect mysql token key column: %v", err)
}
return strings.ToLower(columnType)
case "postgres":
var dataType string
var maxLength sql.NullInt64
if err := db.Raw(`SELECT data_type, character_maximum_length
FROM information_schema.columns
WHERE table_schema = current_schema() AND table_name = ? AND column_name = ?`,
"tokens", "key").Row().Scan(&dataType, &maxLength); err != nil {
t.Fatalf("failed to inspect postgres token key column: %v", err)
}
switch strings.ToLower(dataType) {
case "character varying":
return fmt.Sprintf("varchar(%d)", maxLength.Int64)
case "character":
return fmt.Sprintf("char(%d)", maxLength.Int64)
default:
if maxLength.Valid {
return fmt.Sprintf("%s(%d)", strings.ToLower(dataType), maxLength.Int64)
}
return strings.ToLower(dataType)
}
default:
t.Fatalf("unsupported dialect %q", dialect)
return ""
}
}
func runTokenMigrationCompatibilityTest(t *testing.T, db *gorm.DB, dialect string, managedTokensTable *bool) {
t.Helper()
legacyKey := strings.Repeat("a", 48)
longKey := strings.Repeat("b", 64)
if err := db.AutoMigrate(&legacyToken{}); err != nil {
t.Fatalf("failed to create legacy token schema: %v", err)
}
if managedTokensTable != nil {
*managedTokensTable = true
}
if err := db.Create(&legacyToken{
UserId: 7,
Key: legacyKey,
Status: common.TokenStatusEnabled,
Name: "legacy-token",
CreatedTime: 1,
AccessedTime: 1,
ExpiredTime: -1,
RemainQuota: 100,
UnlimitedQuota: true,
ModelLimitsEnabled: false,
ModelLimits: "",
AllowIps: common.GetPointer(""),
UsedQuota: 0,
Group: "default",
CrossGroupRetry: false,
}).Error; err != nil {
t.Fatalf("failed to seed legacy token row: %v", err)
}
if got := getTokenKeyColumnType(t, db, dialect); got != "char(48)" {
t.Fatalf("expected legacy key column type char(48), got %q", got)
}
migrateTokenControllerTestDB(t, db)
if got := getTokenKeyColumnType(t, db, dialect); got != "varchar(128)" {
t.Fatalf("expected migrated key column type varchar(128), got %q", got)
}
var migratedToken model.Token
if err := db.First(&migratedToken, "name = ?", "legacy-token").Error; err != nil {
t.Fatalf("failed to load migrated token row: %v", err)
}
if migratedToken.Key != legacyKey {
t.Fatalf("expected migrated token key %q, got %q", legacyKey, migratedToken.Key)
}
if migratedToken.Name != "legacy-token" {
t.Fatalf("expected migrated token name to be preserved, got %q", migratedToken.Name)
}
inserted := model.Token{
UserId: 8,
Name: "long-token",
Key: longKey,
Status: common.TokenStatusEnabled,
CreatedTime: 1,
AccessedTime: 1,
ExpiredTime: -1,
RemainQuota: 200,
UnlimitedQuota: true,
ModelLimitsEnabled: false,
ModelLimits: "",
AllowIps: common.GetPointer(""),
UsedQuota: 0,
Group: "default",
CrossGroupRetry: false,
}
if err := db.Create(&inserted).Error; err != nil {
t.Fatalf("failed to insert long token after migration: %v", err)
}
var fetched model.Token
if err := db.First(&fetched, "id = ?", inserted.Id).Error; err != nil {
t.Fatalf("failed to fetch long token after migration: %v", err)
}
if fetched.Key != longKey {
t.Fatalf("expected long token key %q, got %q", longKey, fetched.Key)
}
}
func TestTokenAutoMigrateUsesVarchar128KeyColumn(t *testing.T) {
db := setupTokenControllerTestDB(t)
if got := getTokenKeyColumnType(t, db, "sqlite"); got != "varchar(128)" {
t.Fatalf("expected key column type varchar(128), got %q", got)
}
}
func TestTokenMigrationFromChar48ToVarchar128(t *testing.T) {
db := openTokenControllerTestDB(t)
runTokenMigrationCompatibilityTest(t, db, "sqlite", nil)
}
func TestTokenMigrationFromChar48ToVarchar128MySQL(t *testing.T) {
dsn := os.Getenv("TEST_MYSQL_DSN")
if dsn == "" {
t.Skip("set TEST_MYSQL_DSN to run mysql migration compatibility test")
}
db, managedTokensTable := openTokenControllerExternalDB(t, "mysql", dsn)
runTokenMigrationCompatibilityTest(t, db, "mysql", managedTokensTable)
}
func TestTokenMigrationFromChar48ToVarchar128Postgres(t *testing.T) {
dsn := os.Getenv("TEST_POSTGRES_DSN")
if dsn == "" {
t.Skip("set TEST_POSTGRES_DSN to run postgres migration compatibility test")
}
db, managedTokensTable := openTokenControllerExternalDB(t, "postgres", dsn)
runTokenMigrationCompatibilityTest(t, db, "postgres", managedTokensTable)
}
func TestGetAllTokensMasksKeyInResponse(t *testing.T) { func TestGetAllTokensMasksKeyInResponse(t *testing.T) {
db := setupTokenControllerTestDB(t) db := setupTokenControllerTestDB(t)
token := seedToken(t, db, 1, "list-token", "abcd1234efgh5678") token := seedToken(t, db, 1, "list-token", "abcd1234efgh5678")
+103 -57
View File
@@ -2,7 +2,7 @@ package controller
import ( import (
"fmt" "fmt"
"log" "net/http"
"net/url" "net/url"
"strconv" "strconv"
"sync" "sync"
@@ -27,7 +27,7 @@ func GetTopUpInfo(c *gin.Context) {
payMethods := operation_setting.PayMethods payMethods := operation_setting.PayMethods
// 如果启用了 Stripe 支付,添加到支付方法列表 // 如果启用了 Stripe 支付,添加到支付方法列表
if setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "" { if isStripeTopUpEnabled() {
// 检查是否已经包含 Stripe // 检查是否已经包含 Stripe
hasStripe := false hasStripe := false
for _, method := range payMethods { for _, method := range payMethods {
@@ -49,19 +49,11 @@ func GetTopUpInfo(c *gin.Context) {
} }
// 如果启用了 Waffo 支付,添加到支付方法列表 // 如果启用了 Waffo 支付,添加到支付方法列表
enableWaffo := setting.WaffoEnabled && enableWaffo := isWaffoTopUpEnabled()
((!setting.WaffoSandbox &&
setting.WaffoApiKey != "" &&
setting.WaffoPrivateKey != "" &&
setting.WaffoPublicCert != "") ||
(setting.WaffoSandbox &&
setting.WaffoSandboxApiKey != "" &&
setting.WaffoSandboxPrivateKey != "" &&
setting.WaffoSandboxPublicCert != ""))
if enableWaffo { if enableWaffo {
hasWaffo := false hasWaffo := false
for _, method := range payMethods { for _, method := range payMethods {
if method["type"] == "waffo" { if method["type"] == model.PaymentMethodWaffo {
hasWaffo = true hasWaffo = true
break break
} }
@@ -70,7 +62,7 @@ func GetTopUpInfo(c *gin.Context) {
if !hasWaffo { if !hasWaffo {
waffoMethod := map[string]string{ waffoMethod := map[string]string{
"name": "Waffo (Global Payment)", "name": "Waffo (Global Payment)",
"type": "waffo", "type": model.PaymentMethodWaffo,
"color": "rgba(var(--semi-blue-5), 1)", "color": "rgba(var(--semi-blue-5), 1)",
"min_topup": strconv.Itoa(setting.WaffoMinTopUp), "min_topup": strconv.Itoa(setting.WaffoMinTopUp),
} }
@@ -78,24 +70,46 @@ func GetTopUpInfo(c *gin.Context) {
} }
} }
enableWaffoPancake := isWaffoPancakeTopUpEnabled()
if enableWaffoPancake {
hasWaffoPancake := false
for _, method := range payMethods {
if method["type"] == model.PaymentMethodWaffoPancake {
hasWaffoPancake = true
break
}
}
if !hasWaffoPancake {
payMethods = append(payMethods, map[string]string{
"name": "Waffo Pancake",
"type": model.PaymentMethodWaffoPancake,
"color": "rgba(var(--semi-orange-5), 1)",
"min_topup": strconv.Itoa(setting.WaffoPancakeMinTopUp),
})
}
}
data := gin.H{ data := gin.H{
"enable_online_topup": operation_setting.PayAddress != "" && operation_setting.EpayId != "" && operation_setting.EpayKey != "", "enable_online_topup": isEpayTopUpEnabled(),
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "", "enable_stripe_topup": isStripeTopUpEnabled(),
"enable_creem_topup": setting.CreemApiKey != "" && setting.CreemProducts != "[]", "enable_creem_topup": isCreemTopUpEnabled(),
"enable_waffo_topup": enableWaffo, "enable_waffo_topup": enableWaffo,
"enable_waffo_pancake_topup": enableWaffoPancake,
"waffo_pay_methods": func() interface{} { "waffo_pay_methods": func() interface{} {
if enableWaffo { if enableWaffo {
return setting.GetWaffoPayMethods() return setting.GetWaffoPayMethods()
} }
return nil return nil
}(), }(),
"creem_products": setting.CreemProducts, "creem_products": setting.CreemProducts,
"pay_methods": payMethods, "pay_methods": payMethods,
"min_topup": operation_setting.MinTopUp, "min_topup": operation_setting.MinTopUp,
"stripe_min_topup": setting.StripeMinTopUp, "stripe_min_topup": setting.StripeMinTopUp,
"waffo_min_topup": setting.WaffoMinTopUp, "waffo_min_topup": setting.WaffoMinTopUp,
"amount_options": operation_setting.GetPaymentSetting().AmountOptions, "waffo_pancake_min_topup": setting.WaffoPancakeMinTopUp,
"discount": operation_setting.GetPaymentSetting().AmountDiscount, "amount_options": operation_setting.GetPaymentSetting().AmountOptions,
"discount": operation_setting.GetPaymentSetting().AmountDiscount,
} }
common.ApiSuccess(c, data) common.ApiSuccess(c, data)
} }
@@ -109,6 +123,17 @@ type AmountRequest struct {
Amount int64 `json:"amount"` Amount int64 `json:"amount"`
} }
var nonEpayPaymentMethodsForCallback = []string{
model.PaymentMethodStripe,
model.PaymentMethodCreem,
model.PaymentMethodWaffo,
model.PaymentMethodWaffoPancake,
}
func isNonEpayPaymentMethodForEpayCallback(paymentMethod string) bool {
return lo.Contains(nonEpayPaymentMethodsForCallback, paymentMethod)
}
func GetEpayClient() *epay.Client { func GetEpayClient() *epay.Client {
if operation_setting.PayAddress == "" || operation_setting.EpayId == "" || operation_setting.EpayKey == "" { if operation_setting.PayAddress == "" || operation_setting.EpayId == "" || operation_setting.EpayKey == "" {
return nil return nil
@@ -167,28 +192,28 @@ func RequestEpay(c *gin.Context) {
var req EpayRequest var req EpayRequest
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
if err != nil { if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return return
} }
if req.Amount < getMinTopup() { if req.Amount < getMinTopup() {
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
return return
} }
id := c.GetInt("id") id := c.GetInt("id")
group, err := model.GetUserGroup(id, true) group, err := model.GetUserGroup(id, true)
if err != nil { if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "获取用户分组失败"})
return return
} }
payMoney := getPayMoney(req.Amount, group) payMoney := getPayMoney(req.Amount, group)
if payMoney < 0.01 { if payMoney < 0.01 {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "充值金额过低"})
return return
} }
if !operation_setting.ContainsPayMethod(req.PaymentMethod) { if !operation_setting.ContainsPayMethod(req.PaymentMethod) {
c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "支付方式不存在"})
return return
} }
@@ -199,7 +224,7 @@ func RequestEpay(c *gin.Context) {
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo) tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
client := GetEpayClient() client := GetEpayClient()
if client == nil { if client == nil {
c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "当前管理员未配置支付信息"})
return return
} }
uri, params, err := client.Purchase(&epay.PurchaseArgs{ uri, params, err := client.Purchase(&epay.PurchaseArgs{
@@ -212,7 +237,8 @@ func RequestEpay(c *gin.Context) {
ReturnUrl: returnUrl, ReturnUrl: returnUrl,
}) })
if err != nil { if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"}) logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 拉起支付失败 user_id=%d trade_no=%s payment_method=%s amount=%d error=%q", id, tradeNo, req.PaymentMethod, req.Amount, err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
return return
} }
amount := req.Amount amount := req.Amount
@@ -228,14 +254,16 @@ func RequestEpay(c *gin.Context) {
TradeNo: tradeNo, TradeNo: tradeNo,
PaymentMethod: req.PaymentMethod, PaymentMethod: req.PaymentMethod,
CreateTime: time.Now().Unix(), CreateTime: time.Now().Unix(),
Status: "pending", Status: common.TopUpStatusPending,
} }
err = topUp.Insert() err = topUp.Insert()
if err != nil { if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"}) logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 创建充值订单失败 user_id=%d trade_no=%s payment_method=%s amount=%d error=%q", id, tradeNo, req.PaymentMethod, req.Amount, err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
return return
} }
c.JSON(200, gin.H{"message": "success", "data": params, "url": uri}) logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 充值订单创建成功 user_id=%d trade_no=%s payment_method=%s amount=%d money=%.2f uri=%q params=%q", id, tradeNo, req.PaymentMethod, req.Amount, payMoney, uri, common.GetJsonString(params)))
c.JSON(http.StatusOK, gin.H{"message": "success", "data": params, "url": uri})
} }
// tradeNo lock // tradeNo lock
@@ -281,12 +309,18 @@ func UnlockOrder(tradeNo string) {
} }
func EpayNotify(c *gin.Context) { func EpayNotify(c *gin.Context) {
if !isEpayWebhookEnabled() {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 webhook 被拒绝 reason=webhook_disabled path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
_, _ = c.Writer.Write([]byte("fail"))
return
}
var params map[string]string var params map[string]string
if c.Request.Method == "POST" { if c.Request.Method == "POST" {
// POST 请求:从 POST body 解析参数 // POST 请求:从 POST body 解析参数
if err := c.Request.ParseForm(); err != nil { if err := c.Request.ParseForm(); err != nil {
log.Println("易支付回调POST解析失败:", err) logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 webhook POST 表单解析失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
_, _ = c.Writer.Write([]byte("fail")) _, _ = c.Writer.Write([]byte("fail"))
return return
} }
@@ -301,50 +335,63 @@ func EpayNotify(c *gin.Context) {
return r return r
}, map[string]string{}) }, map[string]string{})
} }
logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 webhook 收到请求 path=%q client_ip=%s method=%s params=%q", c.Request.RequestURI, c.ClientIP(), c.Request.Method, common.GetJsonString(params)))
if len(params) == 0 { if len(params) == 0 {
log.Println("易支付回调参数为空") logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 webhook 参数为空 path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
_, _ = c.Writer.Write([]byte("fail")) _, _ = c.Writer.Write([]byte("fail"))
return return
} }
client := GetEpayClient() client := GetEpayClient()
if client == nil { if client == nil {
log.Println("易支付回调失败 未找到配置信息") logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 client 未初始化 path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
_, err := c.Writer.Write([]byte("fail")) _, err := c.Writer.Write([]byte("fail"))
if err != nil { if err != nil {
log.Println("易支付回调写入失败") logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 webhook 响应写入失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
} }
return return
} }
verifyInfo, err := client.Verify(params) verifyInfo, err := client.Verify(params)
if err == nil && verifyInfo.VerifyStatus { if err == nil && verifyInfo.VerifyStatus {
logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 webhook 验签成功 trade_no=%s callback_type=%s trade_status=%s client_ip=%s verify_info=%q", verifyInfo.ServiceTradeNo, verifyInfo.Type, verifyInfo.TradeStatus, c.ClientIP(), common.GetJsonString(verifyInfo)))
_, err := c.Writer.Write([]byte("success")) _, err := c.Writer.Write([]byte("success"))
if err != nil { if err != nil {
log.Println("易支付回调写入失败") logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 webhook 响应写入失败 trade_no=%s client_ip=%s error=%q", verifyInfo.ServiceTradeNo, c.ClientIP(), err.Error()))
} }
} else { } else {
_, err := c.Writer.Write([]byte("fail")) _, err := c.Writer.Write([]byte("fail"))
if err != nil { if err != nil {
log.Println("易支付回调写入失败") logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 webhook 响应写入失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
}
if err != nil {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 webhook 验签失败 path=%q client_ip=%s verify_error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
} else {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 webhook 验签失败 path=%q client_ip=%s verify_status=false", c.Request.RequestURI, c.ClientIP()))
} }
log.Println("易支付回调签名验证失败")
return return
} }
if verifyInfo.TradeStatus == epay.StatusTradeSuccess { if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
log.Println(verifyInfo)
LockOrder(verifyInfo.ServiceTradeNo) LockOrder(verifyInfo.ServiceTradeNo)
defer UnlockOrder(verifyInfo.ServiceTradeNo) defer UnlockOrder(verifyInfo.ServiceTradeNo)
topUp := model.GetTopUpByTradeNo(verifyInfo.ServiceTradeNo) topUp := model.GetTopUpByTradeNo(verifyInfo.ServiceTradeNo)
if topUp == nil { if topUp == nil {
log.Printf("易支付回调未找到订单: %v", verifyInfo) logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 回调订单不存在 trade_no=%s callback_type=%s client_ip=%s verify_info=%q", verifyInfo.ServiceTradeNo, verifyInfo.Type, c.ClientIP(), common.GetJsonString(verifyInfo)))
return return
} }
if topUp.Status == "pending" { if isNonEpayPaymentMethodForEpayCallback(topUp.PaymentMethod) {
topUp.Status = "success" logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 订单支付方式不匹配 trade_no=%s order_payment_method=%s callback_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentMethod, verifyInfo.Type, c.ClientIP()))
return
}
if topUp.PaymentMethod != verifyInfo.Type {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 订单支付方式不匹配 trade_no=%s order_payment_method=%s callback_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentMethod, verifyInfo.Type, c.ClientIP()))
return
}
if topUp.Status == common.TopUpStatusPending {
topUp.Status = common.TopUpStatusSuccess
err := topUp.Update() err := topUp.Update()
if err != nil { if err != nil {
log.Printf("易支付回调更新订单失败: %v", topUp) logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 更新充值订单失败 trade_no=%s user_id=%d client_ip=%s error=%q topup=%q", topUp.TradeNo, topUp.UserId, c.ClientIP(), err.Error(), common.GetJsonString(topUp)))
return return
} }
//user, _ := model.GetUserById(topUp.UserId, false) //user, _ := model.GetUserById(topUp.UserId, false)
@@ -354,14 +401,14 @@ func EpayNotify(c *gin.Context) {
quotaToAdd := int(dAmount.Mul(dQuotaPerUnit).IntPart()) quotaToAdd := int(dAmount.Mul(dQuotaPerUnit).IntPart())
err = model.IncreaseUserQuota(topUp.UserId, quotaToAdd, true) err = model.IncreaseUserQuota(topUp.UserId, quotaToAdd, true)
if err != nil { if err != nil {
log.Printf("易支付回调更新用户失败: %v", topUp) logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 更新用户额度失败 trade_no=%s user_id=%d client_ip=%s quota_to_add=%d error=%q topup=%q", topUp.TradeNo, topUp.UserId, c.ClientIP(), quotaToAdd, err.Error(), common.GetJsonString(topUp)))
return return
} }
log.Printf("易支付回调更新用户成功 %v", topUp) logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 充值成功 trade_no=%s user_id=%d client_ip=%s quota_to_add=%d money=%.2f topup=%q", topUp.TradeNo, topUp.UserId, c.ClientIP(), quotaToAdd, topUp.Money, common.GetJsonString(topUp)))
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", logger.LogQuota(quotaToAdd), topUp.Money)) model.RecordTopupLog(topUp.UserId, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", logger.LogQuota(quotaToAdd), topUp.Money), c.ClientIP(), topUp.PaymentMethod, "epay")
} }
} else { } else {
log.Printf("易支付异常回调: %v", verifyInfo) logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 webhook 忽略事件 trade_no=%s callback_type=%s trade_status=%s client_ip=%s verify_info=%q", verifyInfo.ServiceTradeNo, verifyInfo.Type, verifyInfo.TradeStatus, c.ClientIP(), common.GetJsonString(verifyInfo)))
} }
} }
@@ -369,26 +416,26 @@ func RequestAmount(c *gin.Context) {
var req AmountRequest var req AmountRequest
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
if err != nil { if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return return
} }
if req.Amount < getMinTopup() { if req.Amount < getMinTopup() {
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
return return
} }
id := c.GetInt("id") id := c.GetInt("id")
group, err := model.GetUserGroup(id, true) group, err := model.GetUserGroup(id, true)
if err != nil { if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "获取用户分组失败"})
return return
} }
payMoney := getPayMoney(req.Amount, group) payMoney := getPayMoney(req.Amount, group)
if payMoney <= 0.01 { if payMoney <= 0.01 {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "充值金额过低"})
return return
} }
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)}) c.JSON(http.StatusOK, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
} }
func GetUserTopUps(c *gin.Context) { func GetUserTopUps(c *gin.Context) {
@@ -457,10 +504,9 @@ func AdminCompleteTopUp(c *gin.Context) {
LockOrder(req.TradeNo) LockOrder(req.TradeNo)
defer UnlockOrder(req.TradeNo) defer UnlockOrder(req.TradeNo)
if err := model.ManualCompleteTopUp(req.TradeNo); err != nil { if err := model.ManualCompleteTopUp(req.TradeNo, c.ClientIP()); err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
common.ApiSuccess(c, nil) common.ApiSuccess(c, nil)
} }
+63 -71
View File
@@ -2,6 +2,7 @@ package controller
import ( import (
"bytes" "bytes"
"context"
"crypto/hmac" "crypto/hmac"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
@@ -9,10 +10,10 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting" "github.com/QuantumNous/new-api/setting"
"io" "io"
"log"
"net/http" "net/http"
"time" "time"
@@ -20,10 +21,7 @@ import (
"github.com/thanhpk/randstr" "github.com/thanhpk/randstr"
) )
const ( const CreemSignatureHeader = "creem-signature"
PaymentMethodCreem = "creem"
CreemSignatureHeader = "creem-signature"
)
var creemAdaptor = &CreemAdaptor{} var creemAdaptor = &CreemAdaptor{}
@@ -37,9 +35,9 @@ func generateCreemSignature(payload string, secret string) string {
// 验证Creem webhook签名 // 验证Creem webhook签名
func verifyCreemSignature(payload string, signature string, secret string) bool { func verifyCreemSignature(payload string, signature string, secret string) bool {
if secret == "" { if secret == "" {
log.Printf("Creem webhook secret not set") logger.LogWarn(context.Background(), fmt.Sprintf("Creem webhook secret 未配置 test_mode=%t signature=%q body=%q", setting.CreemTestMode, signature, payload))
if setting.CreemTestMode { if setting.CreemTestMode {
log.Printf("Skip Creem webhook sign verify in test mode") logger.LogInfo(context.Background(), fmt.Sprintf("Creem webhook 验签已跳过 reason=test_mode signature=%q body=%q", signature, payload))
return true return true
} }
return false return false
@@ -66,13 +64,13 @@ type CreemAdaptor struct {
} }
func (*CreemAdaptor) RequestPay(c *gin.Context, req *CreemPayRequest) { func (*CreemAdaptor) RequestPay(c *gin.Context, req *CreemPayRequest) {
if req.PaymentMethod != PaymentMethodCreem { if req.PaymentMethod != model.PaymentMethodCreem {
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "不支持的支付渠道"})
return return
} }
if req.ProductId == "" { if req.ProductId == "" {
c.JSON(200, gin.H{"message": "error", "data": "请选择产品"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "请选择产品"})
return return
} }
@@ -80,8 +78,8 @@ func (*CreemAdaptor) RequestPay(c *gin.Context, req *CreemPayRequest) {
var products []CreemProduct var products []CreemProduct
err := json.Unmarshal([]byte(setting.CreemProducts), &products) err := json.Unmarshal([]byte(setting.CreemProducts), &products)
if err != nil { if err != nil {
log.Println("解析Creem产品列表失败", err) logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 产品配置解析失败 user_id=%d error=%q", c.GetInt("id"), err.Error()))
c.JSON(200, gin.H{"message": "error", "data": "产品配置错误"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "产品配置错误"})
return return
} }
@@ -95,7 +93,7 @@ func (*CreemAdaptor) RequestPay(c *gin.Context, req *CreemPayRequest) {
} }
if selectedProduct == nil { if selectedProduct == nil {
c.JSON(200, gin.H{"message": "error", "data": "产品不存在"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "产品不存在"})
return return
} }
@@ -108,32 +106,32 @@ func (*CreemAdaptor) RequestPay(c *gin.Context, req *CreemPayRequest) {
// 先创建订单记录,使用产品配置的金额和充值额度 // 先创建订单记录,使用产品配置的金额和充值额度
topUp := &model.TopUp{ topUp := &model.TopUp{
UserId: id, UserId: id,
Amount: selectedProduct.Quota, // 充值额度 Amount: selectedProduct.Quota, // 充值额度
Money: selectedProduct.Price, // 支付金额 Money: selectedProduct.Price, // 支付金额
TradeNo: referenceId, TradeNo: referenceId,
CreateTime: time.Now().Unix(), PaymentMethod: model.PaymentMethodCreem,
Status: common.TopUpStatusPending, CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
} }
err = topUp.Insert() err = topUp.Insert()
if err != nil { if err != nil {
log.Printf("创建Creem订单失败: %v", err) logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 创建充值订单失败 user_id=%d trade_no=%s product_id=%s error=%q", id, referenceId, selectedProduct.ProductId, err.Error()))
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
return return
} }
// 创建支付链接,传入用户邮箱 // 创建支付链接,传入用户邮箱
checkoutUrl, err := genCreemLink(referenceId, selectedProduct, user.Email, user.Username) checkoutUrl, err := genCreemLink(c.Request.Context(), referenceId, selectedProduct, user.Email, user.Username)
if err != nil { if err != nil {
log.Printf("获取Creem支付链接失败: %v", err) logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 创建支付链接失败 user_id=%d trade_no=%s product_id=%s error=%q", id, referenceId, selectedProduct.ProductId, err.Error()))
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
return return
} }
log.Printf("Creem订单创建成功 - 用户ID: %d, 订单号: %s, 产品: %s, 充值额度: %d, 支付金额: %.2f", logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 充值订单创建成功 user_id=%d trade_no=%s product_id=%s product_name=%q quota=%d money=%.2f", id, referenceId, selectedProduct.ProductId, selectedProduct.Name, selectedProduct.Quota, selectedProduct.Price))
id, referenceId, selectedProduct.Name, selectedProduct.Quota, selectedProduct.Price)
c.JSON(200, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "success", "message": "success",
"data": gin.H{ "data": gin.H{
"checkout_url": checkoutUrl, "checkout_url": checkoutUrl,
@@ -148,20 +146,19 @@ func RequestCreemPay(c *gin.Context) {
// 读取body内容用于打印,同时保留原始数据供后续使用 // 读取body内容用于打印,同时保留原始数据供后续使用
bodyBytes, err := io.ReadAll(c.Request.Body) bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil { if err != nil {
log.Printf("read creem pay req body err: %v", err) logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 支付请求读取失败 error=%q", err.Error()))
c.JSON(200, gin.H{"message": "error", "data": "read query error"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "read query error"})
return return
} }
// 打印body内容 logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 支付请求已收到 user_id=%d body=%q", c.GetInt("id"), string(bodyBytes)))
log.Printf("creem pay request body: %s", string(bodyBytes))
// 重新设置body供后续的ShouldBindJSON使用 // 重新设置body供后续的ShouldBindJSON使用
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
err = c.ShouldBindJSON(&req) err = c.ShouldBindJSON(&req)
if err != nil { if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return return
} }
creemAdaptor.RequestPay(c, &req) creemAdaptor.RequestPay(c, &req)
@@ -229,35 +226,37 @@ type CreemWebhookEvent struct {
} }
func CreemWebhook(c *gin.Context) { func CreemWebhook(c *gin.Context) {
if !isCreemWebhookEnabled() {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Creem webhook 被拒绝 reason=webhook_disabled path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
c.AbortWithStatus(http.StatusForbidden)
return
}
// 读取body内容用于打印,同时保留原始数据供后续使用 // 读取body内容用于打印,同时保留原始数据供后续使用
bodyBytes, err := io.ReadAll(c.Request.Body) bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil { if err != nil {
log.Printf("读取Creem Webhook请求body失败: %v", err) logger.LogError(c.Request.Context(), fmt.Sprintf("Creem webhook 读取请求体失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
c.AbortWithStatus(http.StatusBadRequest) c.AbortWithStatus(http.StatusBadRequest)
return return
} }
// 获取签名头 // 获取签名头
signature := c.GetHeader(CreemSignatureHeader) signature := c.GetHeader(CreemSignatureHeader)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem webhook 收到请求 path=%q client_ip=%s signature=%q body=%q", c.Request.RequestURI, c.ClientIP(), signature, string(bodyBytes)))
// 打印关键信息(避免输出完整敏感payload) if signature == "" {
log.Printf("Creem Webhook - URI: %s", c.Request.RequestURI) logger.LogWarn(c.Request.Context(), fmt.Sprintf("Creem webhook 缺少签名 path=%q client_ip=%s body=%q", c.Request.RequestURI, c.ClientIP(), string(bodyBytes)))
if setting.CreemTestMode {
log.Printf("Creem Webhook - Signature: %s , Body: %s", signature, bodyBytes)
} else if signature == "" {
log.Printf("Creem Webhook缺少签名头")
c.AbortWithStatus(http.StatusUnauthorized) c.AbortWithStatus(http.StatusUnauthorized)
return return
} }
// 验证签名 // 验证签名
if !verifyCreemSignature(string(bodyBytes), signature, setting.CreemWebhookSecret) { if !verifyCreemSignature(string(bodyBytes), signature, setting.CreemWebhookSecret) {
log.Printf("Creem Webhook签名验证失败") logger.LogWarn(c.Request.Context(), fmt.Sprintf("Creem webhook 验签失败 path=%q client_ip=%s signature=%q body=%q", c.Request.RequestURI, c.ClientIP(), signature, string(bodyBytes)))
c.AbortWithStatus(http.StatusUnauthorized) c.AbortWithStatus(http.StatusUnauthorized)
return return
} }
log.Printf("Creem Webhook签名验证成功") logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem webhook 验签成功 path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
// 重新设置body供后续的ShouldBindJSON使用 // 重新设置body供后续的ShouldBindJSON使用
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
@@ -265,19 +264,19 @@ func CreemWebhook(c *gin.Context) {
// 解析新格式的webhook数据 // 解析新格式的webhook数据
var webhookEvent CreemWebhookEvent var webhookEvent CreemWebhookEvent
if err := c.ShouldBindJSON(&webhookEvent); err != nil { if err := c.ShouldBindJSON(&webhookEvent); err != nil {
log.Printf("解析Creem Webhook参数失败: %v", err) logger.LogError(c.Request.Context(), fmt.Sprintf("Creem webhook 解析失败 path=%q client_ip=%s error=%q body=%q", c.Request.RequestURI, c.ClientIP(), err.Error(), string(bodyBytes)))
c.AbortWithStatus(http.StatusBadRequest) c.AbortWithStatus(http.StatusBadRequest)
return return
} }
log.Printf("Creem Webhook解析成功 - EventType: %s, EventId: %s", webhookEvent.EventType, webhookEvent.Id) logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem webhook 解析成功 event_type=%s event_id=%s request_id=%s order_id=%s order_status=%s", webhookEvent.EventType, webhookEvent.Id, webhookEvent.Object.RequestId, webhookEvent.Object.Order.Id, webhookEvent.Object.Order.Status))
// 根据事件类型处理不同的webhook // 根据事件类型处理不同的webhook
switch webhookEvent.EventType { switch webhookEvent.EventType {
case "checkout.completed": case "checkout.completed":
handleCheckoutCompleted(c, &webhookEvent) handleCheckoutCompleted(c, &webhookEvent)
default: default:
log.Printf("忽略Creem Webhook事件类型: %s", webhookEvent.EventType) logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem webhook 忽略事件 event_type=%s event_id=%s", webhookEvent.EventType, webhookEvent.Id))
c.Status(http.StatusOK) c.Status(http.StatusOK)
} }
} }
@@ -286,7 +285,7 @@ func CreemWebhook(c *gin.Context) {
func handleCheckoutCompleted(c *gin.Context, event *CreemWebhookEvent) { func handleCheckoutCompleted(c *gin.Context, event *CreemWebhookEvent) {
// 验证订单状态 // 验证订单状态
if event.Object.Order.Status != "paid" { if event.Object.Order.Status != "paid" {
log.Printf("订单状态不是已支付: %s, 跳过处理", event.Object.Order.Status) logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 订单状态未支付,忽略处理 request_id=%s order_id=%s order_status=%s", event.Object.RequestId, event.Object.Order.Id, event.Object.Order.Status))
c.Status(http.StatusOK) c.Status(http.StatusOK)
return return
} }
@@ -294,7 +293,7 @@ func handleCheckoutCompleted(c *gin.Context, event *CreemWebhookEvent) {
// 获取引用ID(这是我们创建订单时传递的request_id) // 获取引用ID(这是我们创建订单时传递的request_id)
referenceId := event.Object.RequestId referenceId := event.Object.RequestId
if referenceId == "" { if referenceId == "" {
log.Println("Creem Webhook缺少request_id字段") logger.LogWarn(c.Request.Context(), fmt.Sprintf("Creem webhook 缺少 request_id event_id=%s order_id=%s", event.Id, event.Object.Order.Id))
c.AbortWithStatus(http.StatusBadRequest) c.AbortWithStatus(http.StatusBadRequest)
return return
} }
@@ -302,40 +301,35 @@ func handleCheckoutCompleted(c *gin.Context, event *CreemWebhookEvent) {
// Try complete subscription order first // Try complete subscription order first
LockOrder(referenceId) LockOrder(referenceId)
defer UnlockOrder(referenceId) defer UnlockOrder(referenceId)
if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(event)); err == nil { if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(event), model.PaymentMethodCreem); err == nil {
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 订阅订单处理成功 trade_no=%s creem_order_id=%s", referenceId, event.Object.Order.Id))
c.Status(http.StatusOK) c.Status(http.StatusOK)
return return
} else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) { } else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) {
log.Printf("Creem订阅订单处理失败: %s, 订单号: %s", err.Error(), referenceId) logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 订阅订单处理失败 trade_no=%s creem_order_id=%s error=%q", referenceId, event.Object.Order.Id, err.Error()))
c.AbortWithStatus(http.StatusInternalServerError) c.AbortWithStatus(http.StatusInternalServerError)
return return
} }
// 验证订单类型,目前只处理一次性付款(充值) // 验证订单类型,目前只处理一次性付款(充值)
if event.Object.Order.Type != "onetime" { if event.Object.Order.Type != "onetime" {
log.Printf("暂不支持订单类型: %s, 跳过处理", event.Object.Order.Type) logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 暂不支持订单类型,忽略处理 request_id=%s creem_order_id=%s order_type=%s", referenceId, event.Object.Order.Id, event.Object.Order.Type))
c.Status(http.StatusOK) c.Status(http.StatusOK)
return return
} }
// 记录详细的支付信息 logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 支付完成回调 trade_no=%s creem_order_id=%s amount_paid=%d currency=%s product_name=%q customer_email=%q customer_name=%q", referenceId, event.Object.Order.Id, event.Object.Order.AmountPaid, event.Object.Order.Currency, event.Object.Product.Name, event.Object.Customer.Email, event.Object.Customer.Name))
log.Printf("处理Creem支付完成 - 订单号: %s, Creem订单ID: %s, 支付金额: %d %s, 客户邮箱: <redacted>, 产品: %s",
referenceId,
event.Object.Order.Id,
event.Object.Order.AmountPaid,
event.Object.Order.Currency,
event.Object.Product.Name)
// 查询本地订单确认存在 // 查询本地订单确认存在
topUp := model.GetTopUpByTradeNo(referenceId) topUp := model.GetTopUpByTradeNo(referenceId)
if topUp == nil { if topUp == nil {
log.Printf("Creem充值订单不存在: %s", referenceId) logger.LogWarn(c.Request.Context(), fmt.Sprintf("Creem 充值订单不存在 trade_no=%s creem_order_id=%s", referenceId, event.Object.Order.Id))
c.AbortWithStatus(http.StatusBadRequest) c.AbortWithStatus(http.StatusBadRequest)
return return
} }
if topUp.Status != common.TopUpStatusPending { if topUp.Status != common.TopUpStatusPending {
log.Printf("Creem充值订单状态错误: %s, 当前状态: %s", referenceId, topUp.Status) logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 充值订单状态非 pending,忽略处理 trade_no=%s status=%s creem_order_id=%s", referenceId, topUp.Status, event.Object.Order.Id))
c.Status(http.StatusOK) // 已处理过的订单,返回成功避免重复处理 c.Status(http.StatusOK) // 已处理过的订单,返回成功避免重复处理
return return
} }
@@ -346,21 +340,20 @@ func handleCheckoutCompleted(c *gin.Context, event *CreemWebhookEvent) {
// 防护性检查,确保邮箱和姓名不为空字符串 // 防护性检查,确保邮箱和姓名不为空字符串
if customerEmail == "" { if customerEmail == "" {
log.Printf("警告:Creem回调客户邮箱为空 - 订单号: %s", referenceId) logger.LogWarn(c.Request.Context(), fmt.Sprintf("Creem 回调客户邮箱为空 trade_no=%s creem_order_id=%s", referenceId, event.Object.Order.Id))
} }
if customerName == "" { if customerName == "" {
log.Printf("警告:Creem回调客户姓名为空 - 订单号: %s", referenceId) logger.LogWarn(c.Request.Context(), fmt.Sprintf("Creem 回调客户姓名为空 trade_no=%s creem_order_id=%s", referenceId, event.Object.Order.Id))
} }
err := model.RechargeCreem(referenceId, customerEmail, customerName) err := model.RechargeCreem(referenceId, customerEmail, customerName, c.ClientIP())
if err != nil { if err != nil {
log.Printf("Creem充值处理失败: %s, 订单号: %s", err.Error(), referenceId) logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 充值处理失败 trade_no=%s creem_order_id=%s client_ip=%s error=%q", referenceId, event.Object.Order.Id, c.ClientIP(), err.Error()))
c.AbortWithStatus(http.StatusInternalServerError) c.AbortWithStatus(http.StatusInternalServerError)
return return
} }
log.Printf("Creem充值成功 - 订单号: %s, 充值额度: %d, 支付金额: %.2f", logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 充值成功 trade_no=%s creem_order_id=%s quota=%d money=%.2f client_ip=%s", referenceId, event.Object.Order.Id, topUp.Amount, topUp.Money, c.ClientIP()))
referenceId, topUp.Amount, topUp.Money)
c.Status(http.StatusOK) c.Status(http.StatusOK)
} }
@@ -378,7 +371,7 @@ type CreemCheckoutResponse struct {
Id string `json:"id"` Id string `json:"id"`
} }
func genCreemLink(referenceId string, product *CreemProduct, email string, username string) (string, error) { func genCreemLink(ctx context.Context, referenceId string, product *CreemProduct, email string, username string) (string, error) {
if setting.CreemApiKey == "" { if setting.CreemApiKey == "" {
return "", fmt.Errorf("未配置Creem API密钥") return "", fmt.Errorf("未配置Creem API密钥")
} }
@@ -387,7 +380,7 @@ func genCreemLink(referenceId string, product *CreemProduct, email string, usern
apiUrl := "https://api.creem.io/v1/checkouts" apiUrl := "https://api.creem.io/v1/checkouts"
if setting.CreemTestMode { if setting.CreemTestMode {
apiUrl = "https://test-api.creem.io/v1/checkouts" apiUrl = "https://test-api.creem.io/v1/checkouts"
log.Printf("使用Creem测试环境: %s", apiUrl) logger.LogInfo(ctx, fmt.Sprintf("Creem 使用测试环境 api_url=%s", apiUrl))
} }
// 构建请求数据,确保包含用户邮箱 // 构建请求数据,确保包含用户邮箱
@@ -423,8 +416,7 @@ func genCreemLink(referenceId string, product *CreemProduct, email string, usern
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("x-api-key", setting.CreemApiKey) req.Header.Set("x-api-key", setting.CreemApiKey)
log.Printf("发送Creem支付请求 - URL: %s, 产品ID: %s, 用户邮箱: %s, 订单号: %s", logger.LogInfo(ctx, fmt.Sprintf("Creem 支付请求已发送 api_url=%s product_id=%s email=%q trade_no=%s", apiUrl, product.ProductId, email, referenceId))
apiUrl, product.ProductId, email, referenceId)
// 发送请求 // 发送请求
client := &http.Client{ client := &http.Client{
@@ -442,7 +434,7 @@ func genCreemLink(referenceId string, product *CreemProduct, email string, usern
return "", fmt.Errorf("读取响应失败: %v", err) return "", fmt.Errorf("读取响应失败: %v", err)
} }
log.Printf("Creem API resp - status code: %d, resp: %s", resp.StatusCode, string(body)) logger.LogInfo(ctx, fmt.Sprintf("Creem API 响应已收到 trade_no=%s status_code=%d body=%q", referenceId, resp.StatusCode, string(body)))
// 检查响应状态 // 检查响应状态
if resp.StatusCode/100 != 2 { if resp.StatusCode/100 != 2 {
@@ -459,6 +451,6 @@ func genCreemLink(referenceId string, product *CreemProduct, email string, usern
return "", fmt.Errorf("Creem API resp no checkout url ") return "", fmt.Errorf("Creem API resp no checkout url ")
} }
log.Printf("Creem 支付链接创建成功 - 订单号: %s, 支付链接: %s", referenceId, checkoutResp.CheckoutUrl) logger.LogInfo(ctx, fmt.Sprintf("Creem 支付链接创建成功 trade_no=%s response_id=%s checkout_url=%q", referenceId, checkoutResp.Id, checkoutResp.CheckoutUrl))
return checkoutResp.CheckoutUrl, nil return checkoutResp.CheckoutUrl, nil
} }
+31
View File
@@ -0,0 +1,31 @@
package controller
import (
"testing"
"github.com/QuantumNous/new-api/model"
)
func TestIsNonEpayPaymentMethodForEpayCallback(t *testing.T) {
testCases := []struct {
name string
paymentMethod string
expectedBlocked bool
}{
{name: "stripe", paymentMethod: model.PaymentMethodStripe, expectedBlocked: true},
{name: "creem", paymentMethod: model.PaymentMethodCreem, expectedBlocked: true},
{name: "waffo", paymentMethod: model.PaymentMethodWaffo, expectedBlocked: true},
{name: "waffo pancake", paymentMethod: model.PaymentMethodWaffoPancake, expectedBlocked: true},
{name: "alipay", paymentMethod: "alipay", expectedBlocked: false},
{name: "wxpay", paymentMethod: "wxpay", expectedBlocked: false},
{name: "custom epay type", paymentMethod: "custom1", expectedBlocked: false},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if actual := isNonEpayPaymentMethodForEpayCallback(tc.paymentMethod); actual != tc.expectedBlocked {
t.Fatalf("expected blocked=%v, got %v for payment method %q", tc.expectedBlocked, actual, tc.paymentMethod)
}
})
}
}
+122 -52
View File
@@ -1,16 +1,17 @@
package controller package controller
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting" "github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/operation_setting" "github.com/QuantumNous/new-api/setting/operation_setting"
@@ -23,10 +24,6 @@ import (
"github.com/thanhpk/randstr" "github.com/thanhpk/randstr"
) )
const (
PaymentMethodStripe = "stripe"
)
var stripeAdaptor = &StripeAdaptor{} var stripeAdaptor = &StripeAdaptor{}
// StripePayRequest represents a payment request for Stripe checkout. // StripePayRequest represents a payment request for Stripe checkout.
@@ -48,34 +45,34 @@ type StripeAdaptor struct {
func (*StripeAdaptor) RequestAmount(c *gin.Context, req *StripePayRequest) { func (*StripeAdaptor) RequestAmount(c *gin.Context, req *StripePayRequest) {
if req.Amount < getStripeMinTopup() { if req.Amount < getStripeMinTopup() {
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup())}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup())})
return return
} }
id := c.GetInt("id") id := c.GetInt("id")
group, err := model.GetUserGroup(id, true) group, err := model.GetUserGroup(id, true)
if err != nil { if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "获取用户分组失败"})
return return
} }
payMoney := getStripePayMoney(float64(req.Amount), group) payMoney := getStripePayMoney(float64(req.Amount), group)
if payMoney <= 0.01 { if payMoney <= 0.01 {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "充值金额过低"})
return return
} }
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)}) c.JSON(http.StatusOK, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
} }
func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) { func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) {
if req.PaymentMethod != PaymentMethodStripe { if req.PaymentMethod != model.PaymentMethodStripe {
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "不支持的支付渠道"})
return return
} }
if req.Amount < getStripeMinTopup() { if req.Amount < getStripeMinTopup() {
c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup()), "data": 10}) c.JSON(http.StatusOK, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup()), "data": 10})
return return
} }
if req.Amount > 10000 { if req.Amount > 10000 {
c.JSON(200, gin.H{"message": "充值数量不能大于 10000", "data": 10}) c.JSON(http.StatusOK, gin.H{"message": "充值数量不能大于 10000", "data": 10})
return return
} }
@@ -98,8 +95,8 @@ func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) {
payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, req.Amount, req.SuccessURL, req.CancelURL) payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, req.Amount, req.SuccessURL, req.CancelURL)
if err != nil { if err != nil {
log.Println("获取Stripe Checkout支付链接失败", err) logger.LogError(c.Request.Context(), fmt.Sprintf("Stripe 创建 Checkout Session 失败 user_id=%d trade_no=%s amount=%d error=%q", id, referenceId, req.Amount, err.Error()))
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
return return
} }
@@ -108,16 +105,18 @@ func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) {
Amount: req.Amount, Amount: req.Amount,
Money: chargedMoney, Money: chargedMoney,
TradeNo: referenceId, TradeNo: referenceId,
PaymentMethod: PaymentMethodStripe, PaymentMethod: model.PaymentMethodStripe,
CreateTime: time.Now().Unix(), CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending, Status: common.TopUpStatusPending,
} }
err = topUp.Insert() err = topUp.Insert()
if err != nil { if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"}) logger.LogError(c.Request.Context(), fmt.Sprintf("Stripe 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, referenceId, req.Amount, err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
return return
} }
c.JSON(200, gin.H{ logger.LogInfo(c.Request.Context(), fmt.Sprintf("Stripe 充值订单创建成功 user_id=%d trade_no=%s amount=%d money=%.2f", id, referenceId, req.Amount, chargedMoney))
c.JSON(http.StatusOK, gin.H{
"message": "success", "message": "success",
"data": gin.H{ "data": gin.H{
"pay_link": payLink, "pay_link": payLink,
@@ -129,7 +128,7 @@ func RequestStripeAmount(c *gin.Context) {
var req StripePayRequest var req StripePayRequest
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
if err != nil { if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return return
} }
stripeAdaptor.RequestAmount(c, &req) stripeAdaptor.RequestAmount(c, &req)
@@ -139,54 +138,130 @@ func RequestStripePay(c *gin.Context) {
var req StripePayRequest var req StripePayRequest
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
if err != nil { if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return return
} }
stripeAdaptor.RequestPay(c, &req) stripeAdaptor.RequestPay(c, &req)
} }
func StripeWebhook(c *gin.Context) { func StripeWebhook(c *gin.Context) {
ctx := c.Request.Context()
if !isStripeWebhookEnabled() {
logger.LogWarn(ctx, fmt.Sprintf("Stripe webhook 被拒绝 reason=webhook_disabled path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
c.AbortWithStatus(http.StatusForbidden)
return
}
payload, err := io.ReadAll(c.Request.Body) payload, err := io.ReadAll(c.Request.Body)
if err != nil { if err != nil {
log.Printf("解析Stripe Webhook参数失败: %v\n", err) logger.LogError(ctx, fmt.Sprintf("Stripe webhook 读取请求体失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
c.AbortWithStatus(http.StatusServiceUnavailable) c.AbortWithStatus(http.StatusServiceUnavailable)
return return
} }
signature := c.GetHeader("Stripe-Signature") signature := c.GetHeader("Stripe-Signature")
endpointSecret := setting.StripeWebhookSecret logger.LogInfo(ctx, fmt.Sprintf("Stripe webhook 收到请求 path=%q client_ip=%s signature=%q body=%q", c.Request.RequestURI, c.ClientIP(), signature, string(payload)))
event, err := webhook.ConstructEventWithOptions(payload, signature, endpointSecret, webhook.ConstructEventOptions{ event, err := webhook.ConstructEventWithOptions(payload, signature, setting.StripeWebhookSecret, webhook.ConstructEventOptions{
IgnoreAPIVersionMismatch: true, IgnoreAPIVersionMismatch: true,
}) })
if err != nil { if err != nil {
log.Printf("Stripe Webhook验签失败: %v\n", err) logger.LogWarn(ctx, fmt.Sprintf("Stripe webhook 验签失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
c.AbortWithStatus(http.StatusBadRequest) c.AbortWithStatus(http.StatusBadRequest)
return return
} }
callerIp := c.ClientIP()
logger.LogInfo(ctx, fmt.Sprintf("Stripe webhook 验签成功 event_type=%s client_ip=%s path=%q", string(event.Type), callerIp, c.Request.RequestURI))
switch event.Type { switch event.Type {
case stripe.EventTypeCheckoutSessionCompleted: case stripe.EventTypeCheckoutSessionCompleted:
sessionCompleted(event) sessionCompleted(ctx, event, callerIp)
case stripe.EventTypeCheckoutSessionExpired: case stripe.EventTypeCheckoutSessionExpired:
sessionExpired(event) sessionExpired(ctx, event)
case stripe.EventTypeCheckoutSessionAsyncPaymentSucceeded:
sessionAsyncPaymentSucceeded(ctx, event, callerIp)
case stripe.EventTypeCheckoutSessionAsyncPaymentFailed:
sessionAsyncPaymentFailed(ctx, event, callerIp)
default: default:
log.Printf("不支持的Stripe Webhook事件类型: %s\n", event.Type) logger.LogInfo(ctx, fmt.Sprintf("Stripe webhook 忽略事件 event_type=%s client_ip=%s", string(event.Type), callerIp))
} }
c.Status(http.StatusOK) c.Status(http.StatusOK)
} }
func sessionCompleted(event stripe.Event) { func sessionCompleted(ctx context.Context, event stripe.Event, callerIp string) {
customerId := event.GetObjectValue("customer") customerId := event.GetObjectValue("customer")
referenceId := event.GetObjectValue("client_reference_id") referenceId := event.GetObjectValue("client_reference_id")
status := event.GetObjectValue("status") status := event.GetObjectValue("status")
if "complete" != status { if "complete" != status {
log.Println("错误的Stripe Checkout完成状态:", status, ",", referenceId) logger.LogWarn(ctx, fmt.Sprintf("Stripe checkout.completed 状态异常,忽略处理 trade_no=%s status=%s client_ip=%s", referenceId, status, callerIp))
return
}
paymentStatus := event.GetObjectValue("payment_status")
if paymentStatus != "paid" {
logger.LogInfo(ctx, fmt.Sprintf("Stripe Checkout 支付未完成,等待异步结果 trade_no=%s payment_status=%s client_ip=%s", referenceId, paymentStatus, callerIp))
return
}
fulfillOrder(ctx, event, referenceId, customerId, callerIp)
}
// sessionAsyncPaymentSucceeded handles delayed payment methods (bank transfer, SEPA, etc.)
// that confirm payment after the checkout session completes.
func sessionAsyncPaymentSucceeded(ctx context.Context, event stripe.Event, callerIp string) {
customerId := event.GetObjectValue("customer")
referenceId := event.GetObjectValue("client_reference_id")
logger.LogInfo(ctx, fmt.Sprintf("Stripe 异步支付成功 trade_no=%s client_ip=%s", referenceId, callerIp))
fulfillOrder(ctx, event, referenceId, customerId, callerIp)
}
// sessionAsyncPaymentFailed marks orders as failed when delayed payment methods
// ultimately fail (e.g. bank transfer not received, SEPA rejected).
func sessionAsyncPaymentFailed(ctx context.Context, event stripe.Event, callerIp string) {
referenceId := event.GetObjectValue("client_reference_id")
logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败 trade_no=%s client_ip=%s", referenceId, callerIp))
if len(referenceId) == 0 {
logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败事件缺少订单号 client_ip=%s", callerIp))
return
}
LockOrder(referenceId)
defer UnlockOrder(referenceId)
topUp := model.GetTopUpByTradeNo(referenceId)
if topUp == nil {
logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败但本地订单不存在 trade_no=%s client_ip=%s", referenceId, callerIp))
return
}
if topUp.PaymentMethod != model.PaymentMethodStripe {
logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败但订单支付方式不匹配 trade_no=%s payment_method=%s client_ip=%s", referenceId, topUp.PaymentMethod, callerIp))
return
}
if topUp.Status != common.TopUpStatusPending {
logger.LogInfo(ctx, fmt.Sprintf("Stripe 异步支付失败但订单状态非 pending,忽略处理 trade_no=%s status=%s client_ip=%s", referenceId, topUp.Status, callerIp))
return
}
topUp.Status = common.TopUpStatusFailed
if err := topUp.Update(); err != nil {
logger.LogError(ctx, fmt.Sprintf("Stripe 标记充值订单失败状态失败 trade_no=%s client_ip=%s error=%q", referenceId, callerIp, err.Error()))
return
}
logger.LogInfo(ctx, fmt.Sprintf("Stripe 充值订单已标记为失败 trade_no=%s client_ip=%s", referenceId, callerIp))
}
// fulfillOrder is the shared logic for crediting quota after payment is confirmed.
func fulfillOrder(ctx context.Context, event stripe.Event, referenceId string, customerId string, callerIp string) {
if len(referenceId) == 0 {
logger.LogWarn(ctx, fmt.Sprintf("Stripe 完成订单时缺少订单号 client_ip=%s", callerIp))
return return
} }
// Try complete subscription order first
LockOrder(referenceId) LockOrder(referenceId)
defer UnlockOrder(referenceId) defer UnlockOrder(referenceId)
payload := map[string]any{ payload := map[string]any{
@@ -195,65 +270,60 @@ func sessionCompleted(event stripe.Event) {
"currency": strings.ToUpper(event.GetObjectValue("currency")), "currency": strings.ToUpper(event.GetObjectValue("currency")),
"event_type": string(event.Type), "event_type": string(event.Type),
} }
if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(payload)); err == nil { if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(payload), model.PaymentMethodStripe); err == nil {
logger.LogInfo(ctx, fmt.Sprintf("Stripe 订阅订单处理成功 trade_no=%s event_type=%s client_ip=%s", referenceId, string(event.Type), callerIp))
return return
} else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) { } else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) {
log.Println("complete subscription order failed:", err.Error(), referenceId) logger.LogError(ctx, fmt.Sprintf("Stripe 订阅订单处理失败 trade_no=%s event_type=%s client_ip=%s error=%q", referenceId, string(event.Type), callerIp, err.Error()))
return return
} }
err := model.Recharge(referenceId, customerId) err := model.Recharge(referenceId, customerId, callerIp)
if err != nil { if err != nil {
log.Println(err.Error(), referenceId) logger.LogError(ctx, fmt.Sprintf("Stripe 充值处理失败 trade_no=%s event_type=%s client_ip=%s error=%q", referenceId, string(event.Type), callerIp, err.Error()))
return return
} }
total, _ := strconv.ParseFloat(event.GetObjectValue("amount_total"), 64) total, _ := strconv.ParseFloat(event.GetObjectValue("amount_total"), 64)
currency := strings.ToUpper(event.GetObjectValue("currency")) currency := strings.ToUpper(event.GetObjectValue("currency"))
log.Printf("收到款项:%s, %.2f(%s)", referenceId, total/100, currency) logger.LogInfo(ctx, fmt.Sprintf("Stripe 充值成功 trade_no=%s amount_total=%.2f currency=%s event_type=%s client_ip=%s", referenceId, total/100, currency, string(event.Type), callerIp))
} }
func sessionExpired(event stripe.Event) { func sessionExpired(ctx context.Context, event stripe.Event) {
referenceId := event.GetObjectValue("client_reference_id") referenceId := event.GetObjectValue("client_reference_id")
status := event.GetObjectValue("status") status := event.GetObjectValue("status")
if "expired" != status { if "expired" != status {
log.Println("错误的Stripe Checkout过期状态:", status, ",", referenceId) logger.LogWarn(ctx, fmt.Sprintf("Stripe checkout.expired 状态异常,忽略处理 trade_no=%s status=%s", referenceId, status))
return return
} }
if len(referenceId) == 0 { if len(referenceId) == 0 {
log.Println("未提供支付单号") logger.LogWarn(ctx, "Stripe checkout.expired 缺少订单号")
return return
} }
// Subscription order expiration // Subscription order expiration
LockOrder(referenceId) LockOrder(referenceId)
defer UnlockOrder(referenceId) defer UnlockOrder(referenceId)
if err := model.ExpireSubscriptionOrder(referenceId); err == nil { if err := model.ExpireSubscriptionOrder(referenceId, model.PaymentMethodStripe); err == nil {
logger.LogInfo(ctx, fmt.Sprintf("Stripe 订阅订单已过期 trade_no=%s", referenceId))
return return
} else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) { } else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) {
log.Println("过期订阅订单失败", referenceId, ", err:", err.Error()) logger.LogError(ctx, fmt.Sprintf("Stripe 订阅订单过期处理失败 trade_no=%s error=%q", referenceId, err.Error()))
return return
} }
topUp := model.GetTopUpByTradeNo(referenceId) err := model.UpdatePendingTopUpStatus(referenceId, model.PaymentMethodStripe, common.TopUpStatusExpired)
if topUp == nil { if errors.Is(err, model.ErrTopUpNotFound) {
log.Println("充值订单不存在", referenceId) logger.LogWarn(ctx, fmt.Sprintf("Stripe 充值订单不存在,无法标记过期 trade_no=%s", referenceId))
return return
} }
if topUp.Status != common.TopUpStatusPending {
log.Println("充值订单状态错误", referenceId)
}
topUp.Status = common.TopUpStatusExpired
err := topUp.Update()
if err != nil { if err != nil {
log.Println("过期充值订单失败", referenceId, ", err:", err.Error()) logger.LogError(ctx, fmt.Sprintf("Stripe 充值订单过期处理失败 trade_no=%s error=%q", referenceId, err.Error()))
return return
} }
log.Println("充值订单已过期", referenceId) logger.LogInfo(ctx, fmt.Sprintf("Stripe 充值订单已过期 trade_no=%s", referenceId))
} }
// genStripeLink generates a Stripe Checkout session URL for payment. // genStripeLink generates a Stripe Checkout session URL for payment.
+73 -36
View File
@@ -1,14 +1,15 @@
package controller package controller
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"log"
"net/http" "net/http"
"strconv" "strconv"
"time" "time"
"github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/service" "github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting" "github.com/QuantumNous/new-api/setting"
@@ -99,28 +100,57 @@ type WaffoPayRequest struct {
PayMethodName string `json:"pay_method_name"` // Deprecated: 兼容旧前端,优先使用 pay_method_index PayMethodName string `json:"pay_method_name"` // Deprecated: 兼容旧前端,优先使用 pay_method_index
} }
func RequestWaffoAmount(c *gin.Context) {
var req WaffoPayRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
waffoMinTopup := int64(setting.WaffoMinTopUp)
if req.Amount < waffoMinTopup {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", waffoMinTopup)})
return
}
id := c.GetInt("id")
group, err := model.GetUserGroup(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "获取用户分组失败"})
return
}
payMoney := getWaffoPayMoney(float64(req.Amount), group)
if payMoney <= 0.01 {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "充值金额过低"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
}
// RequestWaffoPay 创建 Waffo 支付订单 // RequestWaffoPay 创建 Waffo 支付订单
func RequestWaffoPay(c *gin.Context) { func RequestWaffoPay(c *gin.Context) {
if !setting.WaffoEnabled { if !setting.WaffoEnabled {
c.JSON(200, gin.H{"message": "error", "data": "Waffo 支付未启用"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo 支付未启用"})
return return
} }
var req WaffoPayRequest var req WaffoPayRequest
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return return
} }
waffoMinTopup := int64(setting.WaffoMinTopUp) waffoMinTopup := int64(setting.WaffoMinTopUp)
if req.Amount < waffoMinTopup { if req.Amount < waffoMinTopup {
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", waffoMinTopup)}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", waffoMinTopup)})
return return
} }
id := c.GetInt("id") id := c.GetInt("id")
user, err := model.GetUserById(id, false) user, err := model.GetUserById(id, false)
if err != nil || user == nil { if err != nil || user == nil {
c.JSON(200, gin.H{"message": "error", "data": "用户不存在"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "用户不存在"})
return return
} }
@@ -131,8 +161,8 @@ func RequestWaffoPay(c *gin.Context) {
// 新协议:按索引查找 // 新协议:按索引查找
idx := *req.PayMethodIndex idx := *req.PayMethodIndex
if idx < 0 || idx >= len(methods) { if idx < 0 || idx >= len(methods) {
log.Printf("Waffo 无效的支付方式索引: %d, UserId=%d, 可用范围: [0, %d)", idx, id, len(methods)) logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo 支付方式索引无效 user_id=%d pay_method_index=%d method_count=%d", id, idx, len(methods)))
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付方式"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "不支持的支付方式"})
return return
} }
resolvedPayMethodType = methods[idx].PayMethodType resolvedPayMethodType = methods[idx].PayMethodType
@@ -149,8 +179,8 @@ func RequestWaffoPay(c *gin.Context) {
} }
} }
if !valid { if !valid {
log.Printf("Waffo 无效的支付方式: PayMethodType=%s, PayMethodName=%s, UserId=%d", req.PayMethodType, req.PayMethodName, id) logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo 支付方式无效 user_id=%d pay_method_type=%s pay_method_name=%q", id, req.PayMethodType, req.PayMethodName))
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付方式"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "不支持的支付方式"})
return return
} }
} }
@@ -159,7 +189,7 @@ func RequestWaffoPay(c *gin.Context) {
group, _ := model.GetUserGroup(id, true) group, _ := model.GetUserGroup(id, true)
payMoney := getWaffoPayMoney(float64(req.Amount), group) payMoney := getWaffoPayMoney(float64(req.Amount), group)
if payMoney < 0.01 { if payMoney < 0.01 {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "充值金额过低"})
return return
} }
@@ -182,22 +212,22 @@ func RequestWaffoPay(c *gin.Context) {
Amount: amount, Amount: amount,
Money: payMoney, Money: payMoney,
TradeNo: merchantOrderId, TradeNo: merchantOrderId,
PaymentMethod: "waffo", PaymentMethod: model.PaymentMethodWaffo,
CreateTime: time.Now().Unix(), CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending, Status: common.TopUpStatusPending,
} }
if err := topUp.Insert(); err != nil { if err := topUp.Insert(); err != nil {
log.Printf("Waffo 创建本地订单失败: %v", err) logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, merchantOrderId, req.Amount, err.Error()))
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
return return
} }
sdk, err := getWaffoSDK() sdk, err := getWaffoSDK()
if err != nil { if err != nil {
log.Printf("Waffo SDK 初始化失败: %v", err) logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo SDK 初始化失败 user_id=%d trade_no=%s error=%q", id, merchantOrderId, err.Error()))
topUp.Status = common.TopUpStatusFailed topUp.Status = common.TopUpStatusFailed
_ = topUp.Update() _ = topUp.Update()
c.JSON(200, gin.H{"message": "error", "data": "支付配置错误"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "支付配置错误"})
return return
} }
@@ -238,29 +268,29 @@ func RequestWaffoPay(c *gin.Context) {
} }
resp, err := sdk.Order().Create(c.Request.Context(), createParams, nil) resp, err := sdk.Order().Create(c.Request.Context(), createParams, nil)
if err != nil { if err != nil {
log.Printf("Waffo 创建订单失败: %v", err) logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 创建订单失败 user_id=%d trade_no=%s error=%q", id, merchantOrderId, err.Error()))
topUp.Status = common.TopUpStatusFailed topUp.Status = common.TopUpStatusFailed
_ = topUp.Update() _ = topUp.Update()
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
return return
} }
if !resp.IsSuccess() { if !resp.IsSuccess() {
log.Printf("Waffo 创建订单业务失败: [%s] %s, 完整响应: %+v", resp.Code, resp.Message, resp) logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo 创建订单业务失败 user_id=%d trade_no=%s code=%s message=%q response=%q", id, merchantOrderId, resp.Code, resp.Message, common.GetJsonString(resp)))
topUp.Status = common.TopUpStatusFailed topUp.Status = common.TopUpStatusFailed
_ = topUp.Update() _ = topUp.Update()
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
return return
} }
orderData := resp.GetData() orderData := resp.GetData()
log.Printf("Waffo 订单创建成功 - 用户: %d, 订单: %s, 金额: %.2f", id, merchantOrderId, payMoney) logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo 充值订单创建成功 user_id=%d trade_no=%s amount=%d money=%.2f pay_method_type=%s pay_method_name=%q", id, merchantOrderId, req.Amount, payMoney, resolvedPayMethodType, resolvedPayMethodName))
paymentUrl := orderData.FetchRedirectURL() paymentUrl := orderData.FetchRedirectURL()
if paymentUrl == "" { if paymentUrl == "" {
paymentUrl = orderData.OrderAction paymentUrl = orderData.OrderAction
} }
c.JSON(200, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "success", "message": "success",
"data": gin.H{ "data": gin.H{
"payment_url": paymentUrl, "payment_url": paymentUrl,
@@ -287,16 +317,22 @@ type webhookSubscriptionInfo struct {
// WaffoWebhook 处理 Waffo 回调通知(支付/退款/订阅) // WaffoWebhook 处理 Waffo 回调通知(支付/退款/订阅)
func WaffoWebhook(c *gin.Context) { func WaffoWebhook(c *gin.Context) {
if !isWaffoWebhookEnabled() {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo webhook 被拒绝 reason=webhook_disabled path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
c.AbortWithStatus(http.StatusForbidden)
return
}
bodyBytes, err := io.ReadAll(c.Request.Body) bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil { if err != nil {
log.Printf("Waffo Webhook 读取 body 失败: %v", err) logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo webhook 读取请求体失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
c.AbortWithStatus(http.StatusBadRequest) c.AbortWithStatus(http.StatusBadRequest)
return return
} }
sdk, err := getWaffoSDK() sdk, err := getWaffoSDK()
if err != nil { if err != nil {
log.Printf("Waffo Webhook SDK 初始化失败: %v", err) logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo webhook SDK 初始化失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
c.AbortWithStatus(http.StatusInternalServerError) c.AbortWithStatus(http.StatusInternalServerError)
return return
} }
@@ -304,17 +340,18 @@ func WaffoWebhook(c *gin.Context) {
wh := sdk.Webhook() wh := sdk.Webhook()
bodyStr := string(bodyBytes) bodyStr := string(bodyBytes)
signature := c.GetHeader("X-SIGNATURE") signature := c.GetHeader("X-SIGNATURE")
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo webhook 收到请求 path=%q client_ip=%s signature=%q body=%q", c.Request.RequestURI, c.ClientIP(), signature, bodyStr))
// 验证请求签名 // 验证请求签名
if !wh.VerifySignature(bodyStr, signature) { if !wh.VerifySignature(bodyStr, signature) {
log.Printf("Waffo webhook 签名验证失败") logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo webhook 验签失败 path=%q client_ip=%s signature=%q body=%q", c.Request.RequestURI, c.ClientIP(), signature, bodyStr))
c.AbortWithStatus(http.StatusBadRequest) c.AbortWithStatus(http.StatusBadRequest)
return return
} }
var event core.WebhookEvent var event core.WebhookEvent
if err := common.Unmarshal(bodyBytes, &event); err != nil { if err := common.Unmarshal(bodyBytes, &event); err != nil {
log.Printf("Waffo Webhook 解析失败: %v", err) logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo webhook 解析失败 path=%q client_ip=%s error=%q body=%q", c.Request.RequestURI, c.ClientIP(), err.Error(), bodyStr))
sendWaffoWebhookResponse(c, wh, false, "invalid payload") sendWaffoWebhookResponse(c, wh, false, "invalid payload")
return return
} }
@@ -324,14 +361,14 @@ func WaffoWebhook(c *gin.Context) {
// 解析为扩展类型,区分普通支付和订阅支付 // 解析为扩展类型,区分普通支付和订阅支付
var payload webhookPayloadWithSubInfo var payload webhookPayloadWithSubInfo
if err := common.Unmarshal(bodyBytes, &payload); err != nil { if err := common.Unmarshal(bodyBytes, &payload); err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 支付回调载荷解析失败 event_type=%s client_ip=%s error=%q body=%q", event.EventType, c.ClientIP(), err.Error(), bodyStr))
sendWaffoWebhookResponse(c, wh, false, "invalid payment payload") sendWaffoWebhookResponse(c, wh, false, "invalid payment payload")
return return
} }
log.Printf("Waffo Webhook - EventType: %s, MerchantOrderId: %s, OrderStatus: %s", logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo webhook 验签并解析成功 event_type=%s merchant_order_id=%s order_status=%s client_ip=%s", event.EventType, payload.Result.MerchantOrderID, payload.Result.OrderStatus, c.ClientIP()))
event.EventType, payload.Result.MerchantOrderID, payload.Result.OrderStatus)
handleWaffoPayment(c, wh, &payload.Result.PaymentNotificationResult) handleWaffoPayment(c, wh, &payload.Result.PaymentNotificationResult)
default: default:
log.Printf("Waffo Webhook 未知事件: %s", event.EventType) logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo webhook 忽略事件 event_type=%s client_ip=%s", event.EventType, c.ClientIP()))
sendWaffoWebhookResponse(c, wh, true, "") sendWaffoWebhookResponse(c, wh, true, "")
} }
} }
@@ -339,13 +376,13 @@ func WaffoWebhook(c *gin.Context) {
// handleWaffoPayment 处理支付完成通知 // handleWaffoPayment 处理支付完成通知
func handleWaffoPayment(c *gin.Context, wh *core.WebhookHandler, result *core.PaymentNotificationResult) { func handleWaffoPayment(c *gin.Context, wh *core.WebhookHandler, result *core.PaymentNotificationResult) {
if result.OrderStatus != "PAY_SUCCESS" { if result.OrderStatus != "PAY_SUCCESS" {
log.Printf("Waffo 订单状态非成功: %s, 订单: %s", result.OrderStatus, result.MerchantOrderID) logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo 订单状态非成功,忽略充值 trade_no=%s order_status=%s client_ip=%s", result.MerchantOrderID, result.OrderStatus, c.ClientIP()))
// 终态失败订单标记为 failed,避免永远停在 pending // 终态失败订单标记为 failed,避免永远停在 pending
if result.MerchantOrderID != "" { if result.MerchantOrderID != "" {
if topUp := model.GetTopUpByTradeNo(result.MerchantOrderID); topUp != nil && if err := model.UpdatePendingTopUpStatus(result.MerchantOrderID, model.PaymentMethodWaffo, common.TopUpStatusFailed); err != nil &&
topUp.Status == common.TopUpStatusPending { !errors.Is(err, model.ErrTopUpNotFound) &&
topUp.Status = common.TopUpStatusFailed !errors.Is(err, model.ErrTopUpStatusInvalid) {
_ = topUp.Update() logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 标记失败订单状态失败 trade_no=%s error=%q", result.MerchantOrderID, err.Error()))
} }
} }
sendWaffoWebhookResponse(c, wh, true, "") sendWaffoWebhookResponse(c, wh, true, "")
@@ -357,13 +394,13 @@ func handleWaffoPayment(c *gin.Context, wh *core.WebhookHandler, result *core.Pa
LockOrder(merchantOrderId) LockOrder(merchantOrderId)
defer UnlockOrder(merchantOrderId) defer UnlockOrder(merchantOrderId)
if err := model.RechargeWaffo(merchantOrderId); err != nil { if err := model.RechargeWaffo(merchantOrderId, c.ClientIP()); err != nil {
log.Printf("Waffo 充值处理失败: %v, 订单: %s", err, merchantOrderId) logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 充值处理失败 trade_no=%s client_ip=%s error=%q", merchantOrderId, c.ClientIP(), err.Error()))
sendWaffoWebhookResponse(c, wh, false, err.Error()) sendWaffoWebhookResponse(c, wh, false, err.Error())
return return
} }
log.Printf("Waffo 充值成功 - 订单: %s", merchantOrderId) logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo 充值成功 trade_no=%s client_ip=%s", merchantOrderId, c.ClientIP()))
sendWaffoWebhookResponse(c, wh, true, "") sendWaffoWebhookResponse(c, wh, true, "")
} }
+259
View File
@@ -0,0 +1,259 @@
package controller
import (
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-gonic/gin"
"github.com/shopspring/decimal"
"github.com/thanhpk/randstr"
)
type WaffoPancakePayRequest struct {
Amount int64 `json:"amount"`
}
func RequestWaffoPancakeAmount(c *gin.Context) {
var req WaffoPancakePayRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
if req.Amount < int64(setting.WaffoPancakeMinTopUp) {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", setting.WaffoPancakeMinTopUp)})
return
}
id := c.GetInt("id")
group, err := model.GetUserGroup(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "获取用户分组失败"})
return
}
payMoney := getWaffoPancakePayMoney(req.Amount, group)
if payMoney <= 0.01 {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "充值金额过低"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "success", "data": fmt.Sprintf("%.2f", payMoney)})
}
func getWaffoPancakePayMoney(amount int64, group string) float64 {
dAmount := decimal.NewFromInt(amount)
if operation_setting.GetQuotaDisplayType() == operation_setting.QuotaDisplayTypeTokens {
dAmount = dAmount.Div(decimal.NewFromFloat(common.QuotaPerUnit))
}
topupGroupRatio := common.GetTopupGroupRatio(group)
if topupGroupRatio == 0 {
topupGroupRatio = 1
}
discount := 1.0
if ds, ok := operation_setting.GetPaymentSetting().AmountDiscount[int(amount)]; ok && ds > 0 {
discount = ds
}
payMoney := dAmount.
Mul(decimal.NewFromFloat(setting.WaffoPancakeUnitPrice)).
Mul(decimal.NewFromFloat(topupGroupRatio)).
Mul(decimal.NewFromFloat(discount))
return payMoney.InexactFloat64()
}
func normalizeWaffoPancakeTopUpAmount(amount int64) int64 {
if operation_setting.GetQuotaDisplayType() != operation_setting.QuotaDisplayTypeTokens {
return amount
}
normalized := decimal.NewFromInt(amount).
Div(decimal.NewFromFloat(common.QuotaPerUnit)).
IntPart()
if normalized < 1 {
return 1
}
return normalized
}
func formatWaffoPancakeAmount(payMoney float64) string {
return decimal.NewFromFloat(payMoney).StringFixed(2)
}
func getWaffoPancakeBuyerEmail(user *model.User) string {
if user != nil && strings.TrimSpace(user.Email) != "" {
return user.Email
}
if user != nil {
return fmt.Sprintf("%d@new-api.local", user.Id)
}
return ""
}
func getWaffoPancakeReturnURL() string {
if strings.TrimSpace(setting.WaffoPancakeReturnURL) != "" {
return setting.WaffoPancakeReturnURL
}
return strings.TrimRight(system_setting.ServerAddress, "/") + "/console/topup?show_history=true"
}
func RequestWaffoPancakePay(c *gin.Context) {
if !setting.WaffoPancakeEnabled {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 支付未启用"})
return
}
currentWebhookKey := setting.WaffoPancakeWebhookPublicKey
if setting.WaffoPancakeSandbox {
currentWebhookKey = setting.WaffoPancakeWebhookTestKey
}
if strings.TrimSpace(setting.WaffoPancakeMerchantID) == "" ||
strings.TrimSpace(setting.WaffoPancakePrivateKey) == "" ||
strings.TrimSpace(currentWebhookKey) == "" ||
strings.TrimSpace(setting.WaffoPancakeStoreID) == "" ||
strings.TrimSpace(setting.WaffoPancakeProductID) == "" {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 配置不完整"})
return
}
var req WaffoPancakePayRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
if req.Amount < int64(setting.WaffoPancakeMinTopUp) {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", setting.WaffoPancakeMinTopUp)})
return
}
id := c.GetInt("id")
user, err := model.GetUserById(id, false)
if err != nil || user == nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "用户不存在"})
return
}
group, err := model.GetUserGroup(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "获取用户分组失败"})
return
}
payMoney := getWaffoPancakePayMoney(req.Amount, group)
if payMoney < 0.01 {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "充值金额过低"})
return
}
tradeNo := fmt.Sprintf("WAFFO_PANCAKE-%d-%d-%s", id, time.Now().UnixMilli(), randstr.String(6))
topUp := &model.TopUp{
UserId: id,
Amount: normalizeWaffoPancakeTopUpAmount(req.Amount),
Money: payMoney,
TradeNo: tradeNo,
PaymentMethod: model.PaymentMethodWaffoPancake,
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
}
if err := topUp.Insert(); err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, tradeNo, req.Amount, err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
return
}
expiresInSeconds := 45 * 60
session, err := service.CreateWaffoPancakeCheckoutSession(c.Request.Context(), &service.WaffoPancakeCreateSessionParams{
StoreID: setting.WaffoPancakeStoreID,
ProductID: setting.WaffoPancakeProductID,
ProductType: "onetime",
Currency: strings.ToUpper(strings.TrimSpace(setting.WaffoPancakeCurrency)),
PriceSnapshot: &service.WaffoPancakePriceSnapshot{
Amount: formatWaffoPancakeAmount(payMoney),
TaxIncluded: false,
TaxCategory: "saas",
},
BuyerEmail: getWaffoPancakeBuyerEmail(user),
SuccessURL: getWaffoPancakeReturnURL(),
ExpiresInSeconds: &expiresInSeconds,
})
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 创建结账会话失败 user_id=%d trade_no=%s error=%q", id, tradeNo, err.Error()))
topUp.Status = common.TopUpStatusFailed
_ = topUp.Update()
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
return
}
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo Pancake 充值订单创建成功 user_id=%d trade_no=%s session_id=%s amount=%d money=%.2f", id, tradeNo, session.SessionID, req.Amount, payMoney))
c.JSON(http.StatusOK, gin.H{
"message": "success",
"data": gin.H{
"checkout_url": session.CheckoutURL,
"session_id": session.SessionID,
"expires_at": session.ExpiresAt,
"order_id": tradeNo,
},
})
}
func WaffoPancakeWebhook(c *gin.Context) {
if !isWaffoPancakeWebhookEnabled() {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 被拒绝 reason=webhook_disabled path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
c.String(http.StatusForbidden, "webhook disabled")
return
}
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 读取请求体失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
c.String(http.StatusBadRequest, "bad request")
return
}
signature := c.GetHeader("X-Waffo-Signature")
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 收到请求 path=%q client_ip=%s signature=%q body=%q", c.Request.RequestURI, c.ClientIP(), signature, string(bodyBytes)))
event, err := service.VerifyConfiguredWaffoPancakeWebhook(string(bodyBytes), signature)
if err != nil {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 验签失败 path=%q client_ip=%s signature=%q body=%q error=%q", c.Request.RequestURI, c.ClientIP(), signature, string(bodyBytes), err.Error()))
c.String(http.StatusUnauthorized, "invalid signature")
return
}
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 验签成功 event_type=%s event_id=%s order_id=%s client_ip=%s", event.NormalizedEventType(), event.ID, event.Data.OrderID, c.ClientIP()))
if event.NormalizedEventType() != "order.completed" {
c.String(http.StatusOK, "OK")
return
}
tradeNo, err := service.ResolveWaffoPancakeTradeNo(event)
if err != nil {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 订单号映射失败 event_id=%s order_id=%s error=%q", event.ID, event.Data.OrderID, err.Error()))
c.String(http.StatusOK, "OK")
return
}
LockOrder(tradeNo)
defer UnlockOrder(tradeNo)
if err := model.RechargeWaffoPancake(tradeNo); err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 充值处理失败 trade_no=%s event_id=%s order_id=%s client_ip=%s error=%q", tradeNo, event.ID, event.Data.OrderID, c.ClientIP(), err.Error()))
c.String(http.StatusInternalServerError, "retry")
return
}
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo Pancake 充值成功 trade_no=%s event_id=%s order_id=%s client_ip=%s", tradeNo, event.ID, event.Data.OrderID, c.ClientIP()))
c.String(http.StatusOK, "OK")
}
+91
View File
@@ -0,0 +1,91 @@
package controller
import (
"testing"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/stretchr/testify/require"
)
func TestFormatWaffoPancakeAmount_UsesDisplayPriceString(t *testing.T) {
testCases := []struct {
name string
amount float64
expected string
}{
{name: "whole amount", amount: 29, expected: "29.00"},
{name: "decimal amount", amount: 29.9, expected: "29.90"},
{name: "round half up to cents", amount: 29.999, expected: "30.00"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
require.Equal(t, tc.expected, formatWaffoPancakeAmount(tc.amount))
})
}
}
func TestGetWaffoPancakePayMoney(t *testing.T) {
originalUnitPrice := setting.WaffoPancakeUnitPrice
originalQuotaDisplayType := operation_setting.GetGeneralSetting().QuotaDisplayType
originalDiscounts := make(map[int]float64, len(operation_setting.GetPaymentSetting().AmountDiscount))
for k, v := range operation_setting.GetPaymentSetting().AmountDiscount {
originalDiscounts[k] = v
}
originalTopupGroupRatio := common.TopupGroupRatio2JSONString()
t.Cleanup(func() {
setting.WaffoPancakeUnitPrice = originalUnitPrice
operation_setting.GetGeneralSetting().QuotaDisplayType = originalQuotaDisplayType
operation_setting.GetPaymentSetting().AmountDiscount = originalDiscounts
require.NoError(t, common.UpdateTopupGroupRatioByJSONString(originalTopupGroupRatio))
})
setting.WaffoPancakeUnitPrice = 2.5
operation_setting.GetPaymentSetting().AmountDiscount = map[int]float64{
10: 0.8,
int(common.QuotaPerUnit * 3): 0.5,
20: 0,
}
require.NoError(t, common.UpdateTopupGroupRatioByJSONString(`{"default":1,"vip":1.2}`))
testCases := []struct {
name string
amount int64
group string
quotaDisplayType string
expected float64
}{
{
name: "currency display applies unit price group ratio and discount",
amount: 10,
group: "vip",
quotaDisplayType: operation_setting.QuotaDisplayTypeUSD,
expected: 24,
},
{
name: "tokens display converts quota to display units before pricing",
amount: int64(common.QuotaPerUnit * 3),
group: "vip",
quotaDisplayType: operation_setting.QuotaDisplayTypeTokens,
expected: 4.5,
},
{
name: "non-positive discount falls back to no discount",
amount: 20,
group: "default",
quotaDisplayType: operation_setting.QuotaDisplayTypeUSD,
expected: 50,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
operation_setting.GetGeneralSetting().QuotaDisplayType = tc.quotaDisplayType
actual := getWaffoPancakePayMoney(tc.amount, tc.group)
require.InDelta(t, tc.expected, actual, 0.000001)
})
}
}
+8 -4
View File
@@ -2,7 +2,6 @@ package controller
import ( import (
"errors" "errors"
"fmt"
"net/http" "net/http"
"strconv" "strconv"
@@ -542,10 +541,15 @@ func AdminDisable2FA(c *gin.Context) {
return return
} }
// 记录操作日志 // 记录操作日志:管理员身份通过 admin_info 传递,避免在非管理员可见的日志内容中泄露。
adminId := c.GetInt("id") adminId := c.GetInt("id")
model.RecordLog(userId, model.LogTypeManage, adminName := c.GetString("username")
fmt.Sprintf("管理员(ID:%d)强制禁用了用户的两步验证", adminId)) adminInfo := map[string]interface{}{
"admin_id": adminId,
"admin_username": adminName,
}
model.RecordLogWithAdminInfo(userId, model.LogTypeManage,
"管理员强制禁用了用户的两步验证", adminInfo)
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
+75 -7
View File
@@ -52,10 +52,15 @@ func Login(c *gin.Context) {
} }
err = user.ValidateAndFill() err = user.ValidateAndFill()
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ switch {
"message": err.Error(), case errors.Is(err, model.ErrDatabase):
"success": false, common.SysLog(fmt.Sprintf("Login database error for user %s: %v", username, err))
}) common.ApiErrorI18n(c, i18n.MsgDatabaseError)
case errors.Is(err, model.ErrUserEmptyCredentials):
common.ApiErrorI18n(c, i18n.MsgInvalidParams)
default:
common.ApiErrorI18n(c, i18n.MsgUserUsernameOrPasswordError)
}
return return
} }
@@ -572,9 +577,6 @@ func UpdateUser(c *gin.Context) {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
if originUser.Quota != updatedUser.Quota {
model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", logger.LogQuota(originUser.Quota), logger.LogQuota(updatedUser.Quota)))
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
@@ -841,6 +843,8 @@ func CreateUser(c *gin.Context) {
type ManageRequest struct { type ManageRequest struct {
Id int `json:"id"` Id int `json:"id"`
Action string `json:"action"` Action string `json:"action"`
Value int `json:"value"`
Mode string `json:"mode"`
} }
// ManageUser Only admin user can do this // ManageUser Only admin user can do this
@@ -887,6 +891,11 @@ func ManageUser(c *gin.Context) {
}) })
return return
} }
// 删除用户后,强制清理 Redis 中所有该用户令牌的缓存,
// 避免已缓存的令牌在 TTL 过期前仍能通过 TokenAuth 校验。
if err := model.InvalidateUserTokensCache(user.Id); err != nil {
common.SysLog(fmt.Sprintf("failed to invalidate tokens cache for user %d: %s", user.Id, err.Error()))
}
case "promote": case "promote":
if myRole != common.RoleRootUser { if myRole != common.RoleRootUser {
common.ApiErrorI18n(c, i18n.MsgUserAdminCannotPromote) common.ApiErrorI18n(c, i18n.MsgUserAdminCannotPromote)
@@ -907,12 +916,71 @@ func ManageUser(c *gin.Context) {
return return
} }
user.Role = common.RoleCommonUser user.Role = common.RoleCommonUser
case "add_quota":
adminName := c.GetString("username")
adminId := c.GetInt("id")
adminInfo := map[string]interface{}{
"admin_id": adminId,
"admin_username": adminName,
}
switch req.Mode {
case "add":
if req.Value <= 0 {
common.ApiErrorI18n(c, i18n.MsgUserQuotaChangeZero)
return
}
if err := model.IncreaseUserQuota(user.Id, req.Value, true); err != nil {
common.ApiError(c, err)
return
}
model.RecordLogWithAdminInfo(user.Id, model.LogTypeManage,
fmt.Sprintf("管理员增加用户额度 %s", logger.LogQuota(req.Value)), adminInfo)
case "subtract":
if req.Value <= 0 {
common.ApiErrorI18n(c, i18n.MsgUserQuotaChangeZero)
return
}
if err := model.DecreaseUserQuota(user.Id, req.Value, true); err != nil {
common.ApiError(c, err)
return
}
model.RecordLogWithAdminInfo(user.Id, model.LogTypeManage,
fmt.Sprintf("管理员减少用户额度 %s", logger.LogQuota(req.Value)), adminInfo)
case "override":
oldQuota := user.Quota
if err := model.DB.Model(&model.User{}).Where("id = ?", user.Id).Update("quota", req.Value).Error; err != nil {
common.ApiError(c, err)
return
}
model.RecordLogWithAdminInfo(user.Id, model.LogTypeManage,
fmt.Sprintf("管理员覆盖用户额度从 %s 为 %s", logger.LogQuota(oldQuota), logger.LogQuota(req.Value)), adminInfo)
default:
common.ApiErrorI18n(c, i18n.MsgInvalidParams)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
} }
if err := user.Update(false); err != nil { if err := user.Update(false); err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
// 禁用 / 角色调整后,强制失效用户缓存与其全部令牌缓存,
// 避免在 Redis TTL 过期前仍使用旧状态(尤其是禁用后仍可发起请求的问题)。
// InvalidateUserCache 会让下一次 GetUserCache 从数据库重新加载,
// InvalidateUserTokensCache 则确保令牌侧的缓存也同步刷新。
if req.Action == "disable" || req.Action == "promote" || req.Action == "demote" {
if err := model.InvalidateUserCache(user.Id); err != nil {
common.SysLog(fmt.Sprintf("failed to invalidate user cache for user %d: %s", user.Id, err.Error()))
}
if err := model.InvalidateUserTokensCache(user.Id); err != nil {
common.SysLog(fmt.Sprintf("failed to invalidate tokens cache for user %d: %s", user.Id, err.Error()))
}
}
clearUser := model.User{ clearUser := model.User{
Role: user.Role, Role: user.Role,
Status: user.Status, Status: user.Status,
+3 -1
View File
@@ -28,10 +28,11 @@ services:
environment: environment:
- SQL_DSN=postgresql://root:123456@postgres:5432/new-api # ⚠️ IMPORTANT: Change the password in production! - SQL_DSN=postgresql://root:123456@postgres:5432/new-api # ⚠️ IMPORTANT: Change the password in production!
# - SQL_DSN=root:123456@tcp(mysql:3306)/new-api # Point to the mysql service, uncomment if using MySQL # - SQL_DSN=root:123456@tcp(mysql:3306)/new-api # Point to the mysql service, uncomment if using MySQL
- REDIS_CONN_STRING=redis://redis - REDIS_CONN_STRING=redis://:123456@redis:6379 # ⚠️ IMPORTANT: Change the password in production!
- TZ=Asia/Shanghai - TZ=Asia/Shanghai
- ERROR_LOG_ENABLED=true # 是否启用错误日志记录 (Whether to enable error log recording) - ERROR_LOG_ENABLED=true # 是否启用错误日志记录 (Whether to enable error log recording)
- BATCH_UPDATE_ENABLED=true # 是否启用批量更新 (Whether to enable batch update) - BATCH_UPDATE_ENABLED=true # 是否启用批量更新 (Whether to enable batch update)
- NODE_NAME=new-api-node-1 # 节点名称,用于审计日志中标识节点身份;多节点/容器部署时建议设置 (Node name used in audit logs; recommended when running multiple instances or in containers)
# - STREAMING_TIMEOUT=300 # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值 (Streaming timeout in seconds, default is 120s. Increase if experiencing empty completions # - STREAMING_TIMEOUT=300 # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值 (Streaming timeout in seconds, default is 120s. Increase if experiencing empty completions
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!! (multi-node deployment, set this to a random string!!!!!!! # - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!! (multi-node deployment, set this to a random string!!!!!!!
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed # - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
@@ -55,6 +56,7 @@ services:
image: redis:latest image: redis:latest
container_name: redis container_name: redis
restart: always restart: always
command: ["redis-server", "--requirepass", "123456"] # ⚠️ IMPORTANT: Change this password in production!
networks: networks:
- new-api-network - new-api-network
+53 -1
View File
@@ -3281,6 +3281,13 @@
} }
] ]
}, },
"cache_control": {
"type": "object",
"properties": {}
},
"inference_geo": {
"type": "string"
},
"max_tokens": { "max_tokens": {
"type": "integer", "type": "integer",
"minimum": 1 "minimum": 1
@@ -3333,7 +3340,8 @@
"enum": [ "enum": [
"auto", "auto",
"any", "any",
"tool" "tool",
"none"
] ]
}, },
"name": { "name": {
@@ -3358,6 +3366,36 @@
} }
} }
}, },
"context_management": {
"type": "object",
"properties": {}
},
"output_config": {
"type": "object",
"properties": {}
},
"output_format": {
"type": "object",
"properties": {}
},
"container": {
"oneOf": [
{
"type": "string"
},
{
"type": "object",
"properties": {}
}
]
},
"mcp_servers": {
"type": "array",
"items": {
"type": "object",
"properties": {}
}
},
"metadata": { "metadata": {
"type": "object", "type": "object",
"properties": { "properties": {
@@ -3365,6 +3403,20 @@
"type": "string" "type": "string"
} }
} }
},
"speed": {
"type": "string",
"enum": [
"standard",
"fast"
]
},
"service_tier": {
"type": "string",
"enum": [
"auto",
"standard_only"
]
} }
} }
}, },
+1
View File
@@ -30,6 +30,7 @@ type ChannelOtherSettings struct {
ClaudeBetaQuery bool `json:"claude_beta_query,omitempty"` // Claude 渠道是否强制追加 ?beta=true ClaudeBetaQuery bool `json:"claude_beta_query,omitempty"` // Claude 渠道是否强制追加 ?beta=true
AllowServiceTier bool `json:"allow_service_tier,omitempty"` // 是否允许 service_tier 透传(默认过滤以避免额外计费) AllowServiceTier bool `json:"allow_service_tier,omitempty"` // 是否允许 service_tier 透传(默认过滤以避免额外计费)
AllowInferenceGeo bool `json:"allow_inference_geo,omitempty"` // 是否允许 inference_geo 透传(仅 Claude,默认过滤以满足数据驻留合规 AllowInferenceGeo bool `json:"allow_inference_geo,omitempty"` // 是否允许 inference_geo 透传(仅 Claude,默认过滤以满足数据驻留合规
AllowSpeed bool `json:"allow_speed,omitempty"` // 是否允许 speed 透传(仅 Claude,默认过滤以避免意外切换推理速度模式)
AllowSafetyIdentifier bool `json:"allow_safety_identifier,omitempty"` // 是否允许 safety_identifier 透传(默认过滤以保护用户隐私) AllowSafetyIdentifier bool `json:"allow_safety_identifier,omitempty"` // 是否允许 safety_identifier 透传(默认过滤以保护用户隐私)
DisableStore bool `json:"disable_store,omitempty"` // 是否禁用 store 透传(默认允许透传,禁用后可能导致 Codex 无法使用) DisableStore bool `json:"disable_store,omitempty"` // 是否禁用 store 透传(默认允许透传,禁用后可能导致 Codex 无法使用)
AllowIncludeObfuscation bool `json:"allow_include_obfuscation,omitempty"` // 是否允许 stream_options.include_obfuscation 透传(默认过滤以避免关闭流混淆保护) AllowIncludeObfuscation bool `json:"allow_include_obfuscation,omitempty"` // 是否允许 stream_options.include_obfuscation 透传(默认过滤以避免关闭流混淆保护)
+13 -4
View File
@@ -204,10 +204,11 @@ type ClaudeToolChoice struct {
} }
type ClaudeRequest struct { type ClaudeRequest struct {
Model string `json:"model"` Model string `json:"model"`
Prompt string `json:"prompt,omitempty"` Prompt string `json:"prompt,omitempty"`
System any `json:"system,omitempty"` System any `json:"system,omitempty"`
Messages []ClaudeMessage `json:"messages,omitempty"` Messages []ClaudeMessage `json:"messages,omitempty"`
CacheControl json.RawMessage `json:"cache_control,omitempty"`
// InferenceGeo controls Claude data residency region. // InferenceGeo controls Claude data residency region.
// This field is filtered by default and can be enabled via channel setting allow_inference_geo. // This field is filtered by default and can be enabled via channel setting allow_inference_geo.
InferenceGeo string `json:"inference_geo,omitempty"` InferenceGeo string `json:"inference_geo,omitempty"`
@@ -227,6 +228,9 @@ type ClaudeRequest struct {
Thinking *Thinking `json:"thinking,omitempty"` Thinking *Thinking `json:"thinking,omitempty"`
McpServers json.RawMessage `json:"mcp_servers,omitempty"` McpServers json.RawMessage `json:"mcp_servers,omitempty"`
Metadata json.RawMessage `json:"metadata,omitempty"` Metadata json.RawMessage `json:"metadata,omitempty"`
// Speed specifies the Claude inference speed mode.
// This field is filtered by default and can be enabled via channel setting allow_speed.
Speed json.RawMessage `json:"speed,omitempty"`
// ServiceTier specifies upstream service level and may affect billing. // ServiceTier specifies upstream service level and may affect billing.
// This field is filtered by default and can be enabled via channel setting allow_service_tier. // This field is filtered by default and can be enabled via channel setting allow_service_tier.
ServiceTier string `json:"service_tier,omitempty"` ServiceTier string `json:"service_tier,omitempty"`
@@ -444,6 +448,11 @@ func ProcessTools(tools []any) ([]*Tool, []*ClaudeWebSearchTool) {
type Thinking struct { type Thinking struct {
Type string `json:"type,omitempty"` Type string `json:"type,omitempty"`
BudgetTokens *int `json:"budget_tokens,omitempty"` BudgetTokens *int `json:"budget_tokens,omitempty"`
// Display controls whether thinking content is returned in the response.
// Used with adaptive thinking on Claude Opus 4.7+: "summarized" restores
// the visible summary that was default on Opus 4.6; "omitted" (default on
// 4.7) suppresses it. Pass-through field from upstream Anthropic API.
Display string `json:"display,omitempty"`
} }
func (c *Thinking) GetBudgetTokens() int { func (c *Thinking) GetBudgetTokens() int {
+1
View File
@@ -46,6 +46,7 @@ func (r *GeminiChatRequest) UnmarshalJSON(data []byte) error {
type ToolConfig struct { type ToolConfig struct {
FunctionCallingConfig *FunctionCallingConfig `json:"functionCallingConfig,omitempty"` FunctionCallingConfig *FunctionCallingConfig `json:"functionCallingConfig,omitempty"`
RetrievalConfig *RetrievalConfig `json:"retrievalConfig,omitempty"` RetrievalConfig *RetrievalConfig `json:"retrievalConfig,omitempty"`
IncludeServerSideToolInvocations *bool `json:"includeServerSideToolInvocations,omitempty"`
} }
type FunctionCallingConfig struct { type FunctionCallingConfig struct {
+1 -1
View File
@@ -273,7 +273,7 @@ type OpenAIResponsesResponse struct {
Status json.RawMessage `json:"status"` Status json.RawMessage `json:"status"`
Error any `json:"error,omitempty"` Error any `json:"error,omitempty"`
IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
Instructions string `json:"instructions"` Instructions json.RawMessage `json:"instructions"`
MaxOutputTokens int `json:"max_output_tokens"` MaxOutputTokens int `json:"max_output_tokens"`
Model string `json:"model"` Model string `json:"model"`
Output []ResponsesOutput `json:"output"` Output []ResponsesOutput `json:"output"`
+22
View File
@@ -5,6 +5,28 @@ import (
"strconv" "strconv"
) )
type StringValue string
func (s *StringValue) UnmarshalJSON(data []byte) error {
var str string
if err := json.Unmarshal(data, &str); err == nil {
*s = StringValue(str)
return nil
}
var raw json.Number
if err := json.Unmarshal(data, &raw); err == nil {
*s = StringValue(raw.String())
return nil
}
return json.Unmarshal(data, &str)
}
func (s StringValue) MarshalJSON() ([]byte, error) {
return json.Marshal(string(s))
}
type IntValue int type IntValue int
func (i *IntValue) UnmarshalJSON(b []byte) error { func (i *IntValue) UnmarshalJSON(b []byte) error {
Generated Vendored
+3 -3
View File
@@ -777,9 +777,9 @@
} }
}, },
"node_modules/@xmldom/xmldom": { "node_modules/@xmldom/xmldom": {
"version": "0.8.12", "version": "0.8.13",
"resolved": "https://registry.npmjs.org/@xmldom/xmldom/-/xmldom-0.8.12.tgz", "resolved": "https://registry.npmjs.org/@xmldom/xmldom/-/xmldom-0.8.13.tgz",
"integrity": "sha512-9k/gHF6n/pAi/9tqr3m3aqkuiNosYTurLLUtc7xQ9sxB/wm7WPygCv8GYa6mS0fLJEHhqMC1ATYhz++U/lRHqg==", "integrity": "sha512-KRYzxepc14G/CEpEGc3Yn+JKaAeT63smlDr+vjB8jRfgTBBI9wRj/nkQEO+ucV8p8I9bfKLWp37uHgFrbntPvw==",
"dev": true, "dev": true,
"license": "MIT", "license": "MIT",
"engines": { "engines": {
+1 -1
View File
@@ -97,7 +97,7 @@ require (
github.com/icza/bitio v1.1.0 // indirect github.com/icza/bitio v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.7.1 // indirect github.com/jackc/pgx/v5 v5.9.2 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/jfreymuth/vorbis v1.0.2 // indirect github.com/jfreymuth/vorbis v1.0.2 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
+2 -2
View File
@@ -154,8 +154,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs= github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw=
github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA= github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jfreymuth/oggvorbis v1.0.5 h1:u+Ck+R0eLSRhgq8WTmffYnrVtSztJcYrl588DM4e3kQ= github.com/jfreymuth/oggvorbis v1.0.5 h1:u+Ck+R0eLSRhgq8WTmffYnrVtSztJcYrl588DM4e3kQ=
+13
View File
@@ -28,6 +28,18 @@ const (
MsgBatchTooMany = "common.batch_too_many" MsgBatchTooMany = "common.batch_too_many"
) )
// Auth middleware messages
const (
MsgAuthNotLoggedIn = "auth.not_logged_in"
MsgAuthAccessTokenInvalid = "auth.access_token_invalid"
MsgAuthUserInfoInvalid = "auth.user_info_invalid"
MsgAuthUserIdNotProvided = "auth.user_id_not_provided"
MsgAuthUserIdFormatError = "auth.user_id_format_error"
MsgAuthUserIdMismatch = "auth.user_id_mismatch"
MsgAuthUserBanned = "auth.user_banned"
MsgAuthInsufficientPrivilege = "auth.insufficient_privilege"
)
// Token related messages // Token related messages
const ( const (
MsgTokenNameTooLong = "token.name_too_long" MsgTokenNameTooLong = "token.name_too_long"
@@ -101,6 +113,7 @@ const (
MsgUserTelegramIdEmpty = "user.telegram_id_empty" MsgUserTelegramIdEmpty = "user.telegram_id_empty"
MsgUserTelegramNotBound = "user.telegram_not_bound" MsgUserTelegramNotBound = "user.telegram_not_bound"
MsgUserLinuxDOIdEmpty = "user.linux_do_id_empty" MsgUserLinuxDOIdEmpty = "user.linux_do_id_empty"
MsgUserQuotaChangeZero = "user.quota_change_zero"
) )
// Quota related messages // Quota related messages
+12 -1
View File
@@ -2,7 +2,7 @@
# Common messages # Common messages
common.invalid_params: "Invalid parameters" common.invalid_params: "Invalid parameters"
common.database_error: "Database error, please try again later" common.database_error: "Database error, please contact the administrator"
common.retry_later: "Please try again later" common.retry_later: "Please try again later"
common.generate_failed: "Generation failed" common.generate_failed: "Generation failed"
common.not_found: "Not found" common.not_found: "Not found"
@@ -23,6 +23,16 @@ common.already_exists: "Already exists"
common.name_cannot_be_empty: "Name cannot be empty" common.name_cannot_be_empty: "Name cannot be empty"
common.batch_too_many: "Too many items in batch request, maximum is {{.Max}}" common.batch_too_many: "Too many items in batch request, maximum is {{.Max}}"
# Auth middleware messages
auth.not_logged_in: "Unauthorized, not logged in and no access token provided"
auth.access_token_invalid: "Unauthorized, invalid access token"
auth.user_info_invalid: "Unauthorized, invalid user info"
auth.user_id_not_provided: "Unauthorized, New-Api-User header not provided"
auth.user_id_format_error: "Unauthorized, New-Api-User header format error"
auth.user_id_mismatch: "Unauthorized, New-Api-User does not match logged in user"
auth.user_banned: "User has been banned"
auth.insufficient_privilege: "Unauthorized, insufficient privileges"
# Token messages # Token messages
token.name_too_long: "Token name is too long" token.name_too_long: "Token name is too long"
token.quota_negative: "Quota value cannot be negative" token.quota_negative: "Quota value cannot be negative"
@@ -91,6 +101,7 @@ user.wechat_id_empty: "WeChat ID is empty!"
user.telegram_id_empty: "Telegram ID is empty!" user.telegram_id_empty: "Telegram ID is empty!"
user.telegram_not_bound: "This Telegram account is not bound" user.telegram_not_bound: "This Telegram account is not bound"
user.linux_do_id_empty: "Linux DO ID is empty!" user.linux_do_id_empty: "Linux DO ID is empty!"
user.quota_change_zero: "Quota change amount cannot be zero"
# Quota messages # Quota messages
quota.negative: "Quota cannot be negative!" quota.negative: "Quota cannot be negative!"
+12 -1
View File
@@ -3,7 +3,7 @@
# Common messages # Common messages
common.invalid_params: "无效的参数" common.invalid_params: "无效的参数"
common.database_error: "数据库错误,请稍后重试" common.database_error: "数据库出错,请联系管理员"
common.retry_later: "请稍后重试" common.retry_later: "请稍后重试"
common.generate_failed: "生成失败" common.generate_failed: "生成失败"
common.not_found: "未找到" common.not_found: "未找到"
@@ -24,6 +24,16 @@ common.already_exists: "已存在"
common.name_cannot_be_empty: "名称不能为空" common.name_cannot_be_empty: "名称不能为空"
common.batch_too_many: "批量请求数量过多,最多 {{.Max}} 条" common.batch_too_many: "批量请求数量过多,最多 {{.Max}} 条"
# Auth middleware messages
auth.not_logged_in: "无权进行此操作,未登录且未提供 access token"
auth.access_token_invalid: "无权进行此操作,access token 无效"
auth.user_info_invalid: "无权进行此操作,用户信息无效"
auth.user_id_not_provided: "无权进行此操作,未提供 New-Api-User"
auth.user_id_format_error: "无权进行此操作,New-Api-User 格式错误"
auth.user_id_mismatch: "无权进行此操作,New-Api-User 与登录用户不匹配"
auth.user_banned: "用户已被封禁"
auth.insufficient_privilege: "无权进行此操作,权限不足"
# Token messages # Token messages
token.name_too_long: "令牌名称过长" token.name_too_long: "令牌名称过长"
token.quota_negative: "额度值不能为负数" token.quota_negative: "额度值不能为负数"
@@ -92,6 +102,7 @@ user.wechat_id_empty: "WeChat id 为空!"
user.telegram_id_empty: "Telegram id 为空!" user.telegram_id_empty: "Telegram id 为空!"
user.telegram_not_bound: "该 Telegram 账户未绑定" user.telegram_not_bound: "该 Telegram 账户未绑定"
user.linux_do_id_empty: "Linux DO id 为空!" user.linux_do_id_empty: "Linux DO id 为空!"
user.quota_change_zero: "额度变更量不能为0"
# Quota messages # Quota messages
quota.negative: "额度不能为负数!" quota.negative: "额度不能为负数!"
+12 -1
View File
@@ -3,7 +3,7 @@
# Common messages # Common messages
common.invalid_params: "無效的參數" common.invalid_params: "無效的參數"
common.database_error: "資料庫錯誤,請稍後重試" common.database_error: "資料庫出錯,請聯繫管理員"
common.retry_later: "請稍後重試" common.retry_later: "請稍後重試"
common.generate_failed: "生成失敗" common.generate_failed: "生成失敗"
common.not_found: "未找到" common.not_found: "未找到"
@@ -24,6 +24,16 @@ common.already_exists: "已存在"
common.name_cannot_be_empty: "名稱不能為空" common.name_cannot_be_empty: "名稱不能為空"
common.batch_too_many: "批次請求數量過多,最多 {{.Max}} 條" common.batch_too_many: "批次請求數量過多,最多 {{.Max}} 條"
# Auth middleware messages
auth.not_logged_in: "無權進行此操作,未登入且未提供 access token"
auth.access_token_invalid: "無權進行此操作,access token 無效"
auth.user_info_invalid: "無權進行此操作,使用者資訊無效"
auth.user_id_not_provided: "無權進行此操作,未提供 New-Api-User"
auth.user_id_format_error: "無權進行此操作,New-Api-User 格式錯誤"
auth.user_id_mismatch: "無權進行此操作,New-Api-User 與登入使用者不匹配"
auth.user_banned: "使用者已被封禁"
auth.insufficient_privilege: "無權進行此操作,權限不足"
# Token messages # Token messages
token.name_too_long: "令牌名稱過長" token.name_too_long: "令牌名稱過長"
token.quota_negative: "額度值不能為負數" token.quota_negative: "額度值不能為負數"
@@ -92,6 +102,7 @@ user.wechat_id_empty: "WeChat id 為空!"
user.telegram_id_empty: "Telegram id 為空!" user.telegram_id_empty: "Telegram id 為空!"
user.telegram_not_bound: "該 Telegram 帳號未綁定" user.telegram_not_bound: "該 Telegram 帳號未綁定"
user.linux_do_id_empty: "Linux DO id 為空!" user.linux_do_id_empty: "Linux DO id 為空!"
user.quota_change_zero: "額度變更量不能為0"
# Quota messages # Quota messages
quota.negative: "額度不能為負數!" quota.negative: "額度不能為負數!"
+57 -20
View File
@@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@@ -9,6 +10,7 @@ import (
"github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/logger" "github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/service" "github.com/QuantumNous/new-api/service"
@@ -17,6 +19,7 @@ import (
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm"
) )
func validUserInfo(username string, role int) bool { func validUserInfo(username string, role int) bool {
@@ -43,17 +46,33 @@ func authHelper(c *gin.Context, minRole int) {
if accessToken == "" { if accessToken == "" {
c.JSON(http.StatusUnauthorized, gin.H{ c.JSON(http.StatusUnauthorized, gin.H{
"success": false, "success": false,
"message": "无权进行此操作,未登录且未提供 access token", "message": common.TranslateMessage(c, i18n.MsgAuthNotLoggedIn),
}) })
c.Abort() c.Abort()
return return
} }
user := model.ValidateAccessToken(accessToken) user, authErr := model.ValidateAccessToken(accessToken)
if authErr != nil {
if errors.Is(authErr, model.ErrDatabase) {
common.SysLog("ValidateAccessToken database error: " + authErr.Error())
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"message": common.TranslateMessage(c, i18n.MsgDatabaseError),
})
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": common.TranslateMessage(c, i18n.MsgAuthAccessTokenInvalid),
})
}
c.Abort()
return
}
if user != nil && user.Username != "" { if user != nil && user.Username != "" {
if !validUserInfo(user.Username, user.Role) { if !validUserInfo(user.Username, user.Role) {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无权进行此操作,用户信息无效", "message": common.TranslateMessage(c, i18n.MsgAuthUserInfoInvalid),
}) })
c.Abort() c.Abort()
return return
@@ -67,7 +86,7 @@ func authHelper(c *gin.Context, minRole int) {
} else { } else {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无权进行此操作,access token 无效", "message": common.TranslateMessage(c, i18n.MsgAuthAccessTokenInvalid),
}) })
c.Abort() c.Abort()
return return
@@ -78,7 +97,7 @@ func authHelper(c *gin.Context, minRole int) {
if apiUserIdStr == "" { if apiUserIdStr == "" {
c.JSON(http.StatusUnauthorized, gin.H{ c.JSON(http.StatusUnauthorized, gin.H{
"success": false, "success": false,
"message": "无权进行此操作,未提供 New-Api-User", "message": common.TranslateMessage(c, i18n.MsgAuthUserIdNotProvided),
}) })
c.Abort() c.Abort()
return return
@@ -87,7 +106,7 @@ func authHelper(c *gin.Context, minRole int) {
if err != nil { if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{ c.JSON(http.StatusUnauthorized, gin.H{
"success": false, "success": false,
"message": "无权进行此操作,New-Api-User 格式错误", "message": common.TranslateMessage(c, i18n.MsgAuthUserIdFormatError),
}) })
c.Abort() c.Abort()
return return
@@ -96,7 +115,7 @@ func authHelper(c *gin.Context, minRole int) {
if id != apiUserId { if id != apiUserId {
c.JSON(http.StatusUnauthorized, gin.H{ c.JSON(http.StatusUnauthorized, gin.H{
"success": false, "success": false,
"message": "无权进行此操作,New-Api-User 与登录用户不匹配", "message": common.TranslateMessage(c, i18n.MsgAuthUserIdMismatch),
}) })
c.Abort() c.Abort()
return return
@@ -104,7 +123,7 @@ func authHelper(c *gin.Context, minRole int) {
if status.(int) == common.UserStatusDisabled { if status.(int) == common.UserStatusDisabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "用户已被封禁", "message": common.TranslateMessage(c, i18n.MsgAuthUserBanned),
}) })
c.Abort() c.Abort()
return return
@@ -112,7 +131,7 @@ func authHelper(c *gin.Context, minRole int) {
if role.(int) < minRole { if role.(int) < minRole {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无权进行此操作,权限不足", "message": common.TranslateMessage(c, i18n.MsgAuthInsufficientPrivilege),
}) })
c.Abort() c.Abort()
return return
@@ -120,7 +139,7 @@ func authHelper(c *gin.Context, minRole int) {
if !validUserInfo(username.(string), role.(int)) { if !validUserInfo(username.(string), role.(int)) {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无权进行此操作,用户信息无效", "message": common.TranslateMessage(c, i18n.MsgAuthUserInfoInvalid),
}) })
c.Abort() c.Abort()
return return
@@ -198,7 +217,7 @@ func TokenAuthReadOnly() func(c *gin.Context) {
if key == "" { if key == "" {
c.JSON(http.StatusUnauthorized, gin.H{ c.JSON(http.StatusUnauthorized, gin.H{
"success": false, "success": false,
"message": "未提供 Authorization 请求头", "message": common.TranslateMessage(c, i18n.MsgTokenNotProvided),
}) })
c.Abort() c.Abort()
return return
@@ -212,19 +231,28 @@ func TokenAuthReadOnly() func(c *gin.Context) {
token, err := model.GetTokenByKey(key, false) token, err := model.GetTokenByKey(key, false)
if err != nil { if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{ if errors.Is(err, gorm.ErrRecordNotFound) {
"success": false, c.JSON(http.StatusUnauthorized, gin.H{
"message": "无效的令牌", "success": false,
}) "message": common.TranslateMessage(c, i18n.MsgTokenInvalid),
})
} else {
common.SysLog("TokenAuthReadOnly GetTokenByKey database error: " + err.Error())
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"message": common.TranslateMessage(c, i18n.MsgDatabaseError),
})
}
c.Abort() c.Abort()
return return
} }
userCache, err := model.GetUserCache(token.UserId) userCache, err := model.GetUserCache(token.UserId)
if err != nil { if err != nil {
common.SysLog(fmt.Sprintf("TokenAuthReadOnly GetUserCache error for user %d: %v", token.UserId, err))
c.JSON(http.StatusInternalServerError, gin.H{ c.JSON(http.StatusInternalServerError, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": common.TranslateMessage(c, i18n.MsgDatabaseError),
}) })
c.Abort() c.Abort()
return return
@@ -232,7 +260,7 @@ func TokenAuthReadOnly() func(c *gin.Context) {
if userCache.Status != common.UserStatusEnabled { if userCache.Status != common.UserStatusEnabled {
c.JSON(http.StatusForbidden, gin.H{ c.JSON(http.StatusForbidden, gin.H{
"success": false, "success": false,
"message": "用户已被封禁", "message": common.TranslateMessage(c, i18n.MsgAuthUserBanned),
}) })
c.Abort() c.Abort()
return return
@@ -309,7 +337,14 @@ func TokenAuth() func(c *gin.Context) {
} }
} }
if err != nil { if err != nil {
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error()) if errors.Is(err, model.ErrDatabase) {
common.SysLog("TokenAuth ValidateUserToken database error: " + err.Error())
abortWithOpenAiMessage(c, http.StatusInternalServerError,
common.TranslateMessage(c, i18n.MsgDatabaseError))
} else {
abortWithOpenAiMessage(c, http.StatusUnauthorized,
common.TranslateMessage(c, i18n.MsgTokenInvalid))
}
return return
} }
@@ -331,12 +366,14 @@ func TokenAuth() func(c *gin.Context) {
userCache, err := model.GetUserCache(token.UserId) userCache, err := model.GetUserCache(token.UserId)
if err != nil { if err != nil {
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error()) common.SysLog(fmt.Sprintf("TokenAuth GetUserCache error for user %d: %v", token.UserId, err))
abortWithOpenAiMessage(c, http.StatusInternalServerError,
common.TranslateMessage(c, i18n.MsgDatabaseError))
return return
} }
userEnabled := userCache.Status == common.UserStatusEnabled userEnabled := userCache.Status == common.UserStatusEnabled
if !userEnabled { if !userEnabled {
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁") abortWithOpenAiMessage(c, http.StatusForbidden, common.TranslateMessage(c, i18n.MsgAuthUserBanned))
return return
} }
+12 -10
View File
@@ -10,7 +10,8 @@ import (
const ( const (
// SecureVerificationSessionKey 安全验证的 session key(与 controller 保持一致) // SecureVerificationSessionKey 安全验证的 session key(与 controller 保持一致)
SecureVerificationSessionKey = "secure_verified_at" SecureVerificationSessionKey = "secure_verified_at"
secureVerificationMethodSessionKey = "secure_verified_method"
// SecureVerificationTimeout 验证有效期(秒) // SecureVerificationTimeout 验证有效期(秒)
SecureVerificationTimeout = 300 // 5分钟 SecureVerificationTimeout = 300 // 5分钟
) )
@@ -48,8 +49,7 @@ func SecureVerificationRequired() gin.HandlerFunc {
verifiedAt, ok := verifiedAtRaw.(int64) verifiedAt, ok := verifiedAtRaw.(int64)
if !ok { if !ok {
// session 数据格式错误 // session 数据格式错误
session.Delete(SecureVerificationSessionKey) clearSecureVerificationSession(session)
_ = session.Save()
c.JSON(http.StatusForbidden, gin.H{ c.JSON(http.StatusForbidden, gin.H{
"success": false, "success": false,
"message": "验证状态异常,请重新验证", "message": "验证状态异常,请重新验证",
@@ -63,8 +63,7 @@ func SecureVerificationRequired() gin.HandlerFunc {
elapsed := time.Now().Unix() - verifiedAt elapsed := time.Now().Unix() - verifiedAt
if elapsed >= SecureVerificationTimeout { if elapsed >= SecureVerificationTimeout {
// 验证已过期,清除 session // 验证已过期,清除 session
session.Delete(SecureVerificationSessionKey) clearSecureVerificationSession(session)
_ = session.Save()
c.JSON(http.StatusForbidden, gin.H{ c.JSON(http.StatusForbidden, gin.H{
"success": false, "success": false,
"message": "验证已过期,请重新验证", "message": "验证已过期,请重新验证",
@@ -74,11 +73,16 @@ func SecureVerificationRequired() gin.HandlerFunc {
return return
} }
// 验证有效,继续处理请求
c.Next() c.Next()
} }
} }
func clearSecureVerificationSession(session sessions.Session) {
session.Delete(SecureVerificationSessionKey)
session.Delete(secureVerificationMethodSessionKey)
_ = session.Save()
}
// OptionalSecureVerification 可选的安全验证中间件 // OptionalSecureVerification 可选的安全验证中间件
// 如果用户已验证,则在 context 中设置标记,但不阻止请求继续 // 如果用户已验证,则在 context 中设置标记,但不阻止请求继续
// 用于某些需要区分是否已验证的场景 // 用于某些需要区分是否已验证的场景
@@ -109,8 +113,7 @@ func OptionalSecureVerification() gin.HandlerFunc {
elapsed := time.Now().Unix() - verifiedAt elapsed := time.Now().Unix() - verifiedAt
if elapsed >= SecureVerificationTimeout { if elapsed >= SecureVerificationTimeout {
session.Delete(SecureVerificationSessionKey) clearSecureVerificationSession(session)
_ = session.Save()
c.Set("secure_verified", false) c.Set("secure_verified", false)
c.Next() c.Next()
return return
@@ -126,6 +129,5 @@ func OptionalSecureVerification() gin.HandlerFunc {
// 用于用户登出或需要强制重新验证的场景 // 用于用户登出或需要强制重新验证的场景
func ClearSecureVerification(c *gin.Context) { func ClearSecureVerification(c *gin.Context) {
session := sessions.Default(c) session := sessions.Default(c)
session.Delete(SecureVerificationSessionKey) clearSecureVerificationSession(session)
_ = session.Save()
} }
+26
View File
@@ -0,0 +1,26 @@
package model
import "errors"
// Common errors
var (
ErrDatabase = errors.New("database error")
)
// User auth errors
var (
ErrInvalidCredentials = errors.New("invalid credentials")
ErrUserEmptyCredentials = errors.New("empty credentials")
)
// Token auth errors
var (
ErrTokenNotProvided = errors.New("token not provided")
ErrTokenInvalid = errors.New("token invalid")
)
// Redemption errors
var ErrRedeemFailed = errors.New("redeem.failed")
// 2FA errors
var ErrTwoFANotEnabled = errors.New("2fa not enabled")
+52
View File
@@ -90,6 +90,58 @@ func RecordLog(userId int, logType int, content string) {
} }
} }
// RecordLogWithAdminInfo 记录操作日志,并将管理员相关信息存入 Other.admin_info
func RecordLogWithAdminInfo(userId int, logType int, content string, adminInfo map[string]interface{}) {
if logType == LogTypeConsume && !common.LogConsumeEnabled {
return
}
username, _ := GetUsernameById(userId, false)
log := &Log{
UserId: userId,
Username: username,
CreatedAt: common.GetTimestamp(),
Type: logType,
Content: content,
}
if len(adminInfo) > 0 {
other := map[string]interface{}{
"admin_info": adminInfo,
}
log.Other = common.MapToJsonStr(other)
}
if err := LOG_DB.Create(log).Error; err != nil {
common.SysLog("failed to record log: " + err.Error())
}
}
func RecordTopupLog(userId int, content string, callerIp string, paymentMethod string, callbackPaymentMethod string) {
username, _ := GetUsernameById(userId, false)
adminInfo := map[string]interface{}{
"server_ip": common.GetIp(),
"node_name": common.NodeName,
"caller_ip": callerIp,
"payment_method": paymentMethod,
"callback_payment_method": callbackPaymentMethod,
"version": common.Version,
}
other := map[string]interface{}{
"admin_info": adminInfo,
}
log := &Log{
UserId: userId,
Username: username,
CreatedAt: common.GetTimestamp(),
Type: LogTypeTopup,
Content: content,
Ip: callerIp,
Other: common.MapToJsonStr(other),
}
err := LOG_DB.Create(log).Error
if err != nil {
common.SysLog("failed to record topup log: " + err.Error())
}
}
func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, tokenName string, content string, tokenId int, useTimeSeconds int, func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, tokenName string, content string, tokenId int, useTimeSeconds int,
isStream bool, group string, other map[string]interface{}) { isStream bool, group string, other map[string]interface{}) {
logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content)) logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
+36
View File
@@ -106,6 +106,18 @@ func InitOptionMap() {
common.OptionMap["WaffoUnitPrice"] = strconv.FormatFloat(setting.WaffoUnitPrice, 'f', -1, 64) common.OptionMap["WaffoUnitPrice"] = strconv.FormatFloat(setting.WaffoUnitPrice, 'f', -1, 64)
common.OptionMap["WaffoMinTopUp"] = strconv.Itoa(setting.WaffoMinTopUp) common.OptionMap["WaffoMinTopUp"] = strconv.Itoa(setting.WaffoMinTopUp)
common.OptionMap["WaffoPayMethods"] = setting.WaffoPayMethods2JsonString() common.OptionMap["WaffoPayMethods"] = setting.WaffoPayMethods2JsonString()
common.OptionMap["WaffoPancakeEnabled"] = strconv.FormatBool(setting.WaffoPancakeEnabled)
common.OptionMap["WaffoPancakeSandbox"] = strconv.FormatBool(setting.WaffoPancakeSandbox)
common.OptionMap["WaffoPancakeMerchantID"] = setting.WaffoPancakeMerchantID
common.OptionMap["WaffoPancakePrivateKey"] = setting.WaffoPancakePrivateKey
common.OptionMap["WaffoPancakeWebhookPublicKey"] = setting.WaffoPancakeWebhookPublicKey
common.OptionMap["WaffoPancakeWebhookTestKey"] = setting.WaffoPancakeWebhookTestKey
common.OptionMap["WaffoPancakeStoreID"] = setting.WaffoPancakeStoreID
common.OptionMap["WaffoPancakeProductID"] = setting.WaffoPancakeProductID
common.OptionMap["WaffoPancakeReturnURL"] = setting.WaffoPancakeReturnURL
common.OptionMap["WaffoPancakeCurrency"] = setting.WaffoPancakeCurrency
common.OptionMap["WaffoPancakeUnitPrice"] = strconv.FormatFloat(setting.WaffoPancakeUnitPrice, 'f', -1, 64)
common.OptionMap["WaffoPancakeMinTopUp"] = strconv.Itoa(setting.WaffoPancakeMinTopUp)
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString() common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["Chats"] = setting.Chats2JsonString() common.OptionMap["Chats"] = setting.Chats2JsonString()
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString() common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
@@ -407,6 +419,30 @@ func updateOptionMap(key string, value string) (err error) {
setting.WaffoUnitPrice, _ = strconv.ParseFloat(value, 64) setting.WaffoUnitPrice, _ = strconv.ParseFloat(value, 64)
case "WaffoMinTopUp": case "WaffoMinTopUp":
setting.WaffoMinTopUp, _ = strconv.Atoi(value) setting.WaffoMinTopUp, _ = strconv.Atoi(value)
case "WaffoPancakeEnabled":
setting.WaffoPancakeEnabled = value == "true"
case "WaffoPancakeSandbox":
setting.WaffoPancakeSandbox = value == "true"
case "WaffoPancakeMerchantID":
setting.WaffoPancakeMerchantID = value
case "WaffoPancakePrivateKey":
setting.WaffoPancakePrivateKey = value
case "WaffoPancakeWebhookPublicKey":
setting.WaffoPancakeWebhookPublicKey = value
case "WaffoPancakeWebhookTestKey":
setting.WaffoPancakeWebhookTestKey = value
case "WaffoPancakeStoreID":
setting.WaffoPancakeStoreID = value
case "WaffoPancakeProductID":
setting.WaffoPancakeProductID = value
case "WaffoPancakeReturnURL":
setting.WaffoPancakeReturnURL = value
case "WaffoPancakeCurrency":
setting.WaffoPancakeCurrency = value
case "WaffoPancakeUnitPrice":
setting.WaffoPancakeUnitPrice, _ = strconv.ParseFloat(value, 64)
case "WaffoPancakeMinTopUp":
setting.WaffoPancakeMinTopUp, _ = strconv.Atoi(value)
case "TopupGroupRatio": case "TopupGroupRatio":
err = common.UpdateTopupGroupRatioByJSONString(value) err = common.UpdateTopupGroupRatioByJSONString(value)
case "GitHubClientId": case "GitHubClientId":
+172
View File
@@ -0,0 +1,172 @@
package model
import (
"testing"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func insertUserForPaymentGuardTest(t *testing.T, id int, quota int) {
t.Helper()
user := &User{
Id: id,
Username: "payment_guard_user",
Status: common.UserStatusEnabled,
Quota: quota,
}
require.NoError(t, DB.Create(user).Error)
}
func insertSubscriptionPlanForPaymentGuardTest(t *testing.T, id int) *SubscriptionPlan {
t.Helper()
plan := &SubscriptionPlan{
Id: id,
Title: "Guard Plan",
PriceAmount: 9.99,
Currency: "USD",
DurationUnit: SubscriptionDurationMonth,
DurationValue: 1,
Enabled: true,
TotalAmount: 1000,
}
require.NoError(t, DB.Create(plan).Error)
return plan
}
func insertSubscriptionOrderForPaymentGuardTest(t *testing.T, tradeNo string, userID int, planID int, paymentMethod string) {
t.Helper()
order := &SubscriptionOrder{
UserId: userID,
PlanId: planID,
Money: 9.99,
TradeNo: tradeNo,
PaymentMethod: paymentMethod,
Status: common.TopUpStatusPending,
CreateTime: time.Now().Unix(),
}
require.NoError(t, order.Insert())
}
func insertTopUpForPaymentGuardTest(t *testing.T, tradeNo string, userID int, paymentMethod string) {
t.Helper()
topUp := &TopUp{
UserId: userID,
Amount: 2,
Money: 9.99,
TradeNo: tradeNo,
PaymentMethod: paymentMethod,
Status: common.TopUpStatusPending,
CreateTime: time.Now().Unix(),
}
require.NoError(t, topUp.Insert())
}
func getTopUpStatusForPaymentGuardTest(t *testing.T, tradeNo string) string {
t.Helper()
topUp := GetTopUpByTradeNo(tradeNo)
require.NotNil(t, topUp)
return topUp.Status
}
func countUserSubscriptionsForPaymentGuardTest(t *testing.T, userID int) int64 {
t.Helper()
var count int64
require.NoError(t, DB.Model(&UserSubscription{}).Where("user_id = ?", userID).Count(&count).Error)
return count
}
func getUserQuotaForPaymentGuardTest(t *testing.T, userID int) int {
t.Helper()
var user User
require.NoError(t, DB.Select("quota").Where("id = ?", userID).First(&user).Error)
return user.Quota
}
func TestRechargeWaffoPancake_RejectsMismatchedPaymentMethod(t *testing.T) {
truncateTables(t)
insertUserForPaymentGuardTest(t, 101, 0)
insertTopUpForPaymentGuardTest(t, "waffo-pancake-guard", 101, PaymentMethodStripe)
err := RechargeWaffoPancake("waffo-pancake-guard")
require.Error(t, err)
topUp := GetTopUpByTradeNo("waffo-pancake-guard")
require.NotNil(t, topUp)
assert.Equal(t, common.TopUpStatusPending, topUp.Status)
assert.Equal(t, 0, getUserQuotaForPaymentGuardTest(t, 101))
}
func TestUpdatePendingTopUpStatus_RejectsMismatchedPaymentMethod(t *testing.T) {
testCases := []struct {
name string
tradeNo string
storedPaymentMethod string
expectedPaymentMethod string
targetStatus string
}{
{
name: "stripe expire",
tradeNo: "stripe-expire-guard",
storedPaymentMethod: PaymentMethodCreem,
expectedPaymentMethod: PaymentMethodStripe,
targetStatus: common.TopUpStatusExpired,
},
{
name: "waffo failed",
tradeNo: "waffo-failed-guard",
storedPaymentMethod: PaymentMethodStripe,
expectedPaymentMethod: PaymentMethodWaffo,
targetStatus: common.TopUpStatusFailed,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
truncateTables(t)
insertUserForPaymentGuardTest(t, 150, 0)
insertTopUpForPaymentGuardTest(t, tc.tradeNo, 150, tc.storedPaymentMethod)
err := UpdatePendingTopUpStatus(tc.tradeNo, tc.expectedPaymentMethod, tc.targetStatus)
require.ErrorIs(t, err, ErrPaymentMethodMismatch)
assert.Equal(t, common.TopUpStatusPending, getTopUpStatusForPaymentGuardTest(t, tc.tradeNo))
})
}
}
func TestCompleteSubscriptionOrder_RejectsMismatchedPaymentMethod(t *testing.T) {
truncateTables(t)
insertUserForPaymentGuardTest(t, 202, 0)
plan := insertSubscriptionPlanForPaymentGuardTest(t, 301)
insertSubscriptionOrderForPaymentGuardTest(t, "sub-guard-order", 202, plan.Id, PaymentMethodStripe)
err := CompleteSubscriptionOrder("sub-guard-order", `{"provider":"epay"}`, "alipay")
require.ErrorIs(t, err, ErrPaymentMethodMismatch)
order := GetSubscriptionOrderByTradeNo("sub-guard-order")
require.NotNil(t, order)
assert.Equal(t, common.TopUpStatusPending, order.Status)
assert.Zero(t, countUserSubscriptionsForPaymentGuardTest(t, 202))
topUp := GetTopUpByTradeNo("sub-guard-order")
assert.Nil(t, topUp)
}
func TestExpireSubscriptionOrder_RejectsMismatchedPaymentMethod(t *testing.T) {
truncateTables(t)
insertUserForPaymentGuardTest(t, 303, 0)
plan := insertSubscriptionPlanForPaymentGuardTest(t, 401)
insertSubscriptionOrderForPaymentGuardTest(t, "sub-expire-guard", 303, plan.Id, PaymentMethodStripe)
err := ExpireSubscriptionOrder("sub-expire-guard", PaymentMethodCreem)
require.ErrorIs(t, err, ErrPaymentMethodMismatch)
order := GetSubscriptionOrderByTradeNo("sub-expire-guard")
require.NotNil(t, order)
assert.Equal(t, common.TopUpStatusPending, order.Status)
}
-3
View File
@@ -11,9 +11,6 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
// ErrRedeemFailed is returned when redemption fails due to database error
var ErrRedeemFailed = errors.New("redeem.failed")
type Redemption struct { type Redemption struct {
Id int `json:"id"` Id int `json:"id"`
UserId int `json:"user_id"` UserId int `json:"user_id"`
+10 -2
View File
@@ -505,7 +505,7 @@ func CreateUserSubscriptionFromPlanTx(tx *gorm.DB, userId int, plan *Subscriptio
} }
// Complete a subscription order (idempotent). Creates a UserSubscription snapshot from the plan. // Complete a subscription order (idempotent). Creates a UserSubscription snapshot from the plan.
func CompleteSubscriptionOrder(tradeNo string, providerPayload string) error { func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedPaymentMethod string) error {
if tradeNo == "" { if tradeNo == "" {
return errors.New("tradeNo is empty") return errors.New("tradeNo is empty")
} }
@@ -523,6 +523,9 @@ func CompleteSubscriptionOrder(tradeNo string, providerPayload string) error {
if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil { if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil {
return ErrSubscriptionOrderNotFound return ErrSubscriptionOrderNotFound
} }
if expectedPaymentMethod != "" && order.PaymentMethod != expectedPaymentMethod {
return ErrPaymentMethodMismatch
}
if order.Status == common.TopUpStatusSuccess { if order.Status == common.TopUpStatusSuccess {
return nil return nil
} }
@@ -596,6 +599,8 @@ func upsertSubscriptionTopUpTx(tx *gorm.DB, order *SubscriptionOrder) error {
topup.Money = order.Money topup.Money = order.Money
if topup.PaymentMethod == "" { if topup.PaymentMethod == "" {
topup.PaymentMethod = order.PaymentMethod topup.PaymentMethod = order.PaymentMethod
} else if topup.PaymentMethod != order.PaymentMethod {
return ErrPaymentMethodMismatch
} }
if topup.CreateTime == 0 { if topup.CreateTime == 0 {
topup.CreateTime = order.CreateTime topup.CreateTime = order.CreateTime
@@ -605,7 +610,7 @@ func upsertSubscriptionTopUpTx(tx *gorm.DB, order *SubscriptionOrder) error {
return tx.Save(&topup).Error return tx.Save(&topup).Error
} }
func ExpireSubscriptionOrder(tradeNo string) error { func ExpireSubscriptionOrder(tradeNo string, expectedPaymentMethod string) error {
if tradeNo == "" { if tradeNo == "" {
return errors.New("tradeNo is empty") return errors.New("tradeNo is empty")
} }
@@ -618,6 +623,9 @@ func ExpireSubscriptionOrder(tradeNo string) error {
if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil { if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil {
return ErrSubscriptionOrderNotFound return ErrSubscriptionOrderNotFound
} }
if expectedPaymentMethod != "" && order.PaymentMethod != expectedPaymentMethod {
return ErrPaymentMethodMismatch
}
if order.Status != common.TopUpStatusPending { if order.Status != common.TopUpStatusPending {
return nil return nil
} }
+15 -1
View File
@@ -33,7 +33,17 @@ func TestMain(m *testing.M) {
} }
sqlDB.SetMaxOpenConns(1) sqlDB.SetMaxOpenConns(1)
if err := db.AutoMigrate(&Task{}, &User{}, &Token{}, &Log{}, &Channel{}); err != nil { if err := db.AutoMigrate(
&Task{},
&User{},
&Token{},
&Log{},
&Channel{},
&TopUp{},
&SubscriptionPlan{},
&SubscriptionOrder{},
&UserSubscription{},
); err != nil {
panic("failed to migrate: " + err.Error()) panic("failed to migrate: " + err.Error())
} }
@@ -48,6 +58,10 @@ func truncateTables(t *testing.T) {
DB.Exec("DELETE FROM tokens") DB.Exec("DELETE FROM tokens")
DB.Exec("DELETE FROM logs") DB.Exec("DELETE FROM logs")
DB.Exec("DELETE FROM channels") DB.Exec("DELETE FROM channels")
DB.Exec("DELETE FROM top_ups")
DB.Exec("DELETE FROM subscription_orders")
DB.Exec("DELETE FROM subscription_plans")
DB.Exec("DELETE FROM user_subscriptions")
}) })
} }
+39 -19
View File
@@ -14,7 +14,7 @@ import (
type Token struct { type Token struct {
Id int `json:"id"` Id int `json:"id"`
UserId int `json:"user_id" gorm:"index"` UserId int `json:"user_id" gorm:"index"`
Key string `json:"key" gorm:"type:char(48);uniqueIndex"` Key string `json:"key" gorm:"type:varchar(128);uniqueIndex"`
Status int `json:"status" gorm:"default:1"` Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index" ` Name string `json:"name" gorm:"index" `
CreatedTime int64 `json:"created_time" gorm:"bigint"` CreatedTime int64 `json:"created_time" gorm:"bigint"`
@@ -187,19 +187,14 @@ func SearchUserTokens(userId int, keyword string, token string, offset int, limi
func ValidateUserToken(key string) (token *Token, err error) { func ValidateUserToken(key string) (token *Token, err error) {
if key == "" { if key == "" {
return nil, errors.New("未提供令牌") return nil, ErrTokenNotProvided
} }
token, err = GetTokenByKey(key, false) token, err = GetTokenByKey(key, false)
if err == nil { if err == nil {
if token.Status == common.TokenStatusExhausted { if token.Status == common.TokenStatusExhausted ||
keyPrefix := key[:3] token.Status == common.TokenStatusExpired ||
keySuffix := key[len(key)-3:] token.Status != common.TokenStatusEnabled {
return token, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]") return token, ErrTokenInvalid
} else if token.Status == common.TokenStatusExpired {
return token, errors.New("该令牌已过期")
}
if token.Status != common.TokenStatusEnabled {
return token, errors.New("该令牌状态不可用")
} }
if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
if !common.RedisEnabled { if !common.RedisEnabled {
@@ -209,29 +204,25 @@ func ValidateUserToken(key string) (token *Token, err error) {
common.SysLog("failed to update token status" + err.Error()) common.SysLog("failed to update token status" + err.Error())
} }
} }
return token, errors.New("该令牌已过期") return token, ErrTokenInvalid
} }
if !token.UnlimitedQuota && token.RemainQuota <= 0 { if !token.UnlimitedQuota && token.RemainQuota <= 0 {
if !common.RedisEnabled { if !common.RedisEnabled {
// in this case, we can make sure the token is exhausted
token.Status = common.TokenStatusExhausted token.Status = common.TokenStatusExhausted
err := token.SelectUpdate() err := token.SelectUpdate()
if err != nil { if err != nil {
common.SysLog("failed to update token status" + err.Error()) common.SysLog("failed to update token status" + err.Error())
} }
} }
keyPrefix := key[:3] return token, ErrTokenInvalid
keySuffix := key[len(key)-3:]
return token, fmt.Errorf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota)
} }
return token, nil return token, nil
} }
common.SysLog("ValidateUserToken: failed to get token: " + err.Error()) common.SysLog("ValidateUserToken: failed to get token: " + err.Error())
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("无效的令牌") return nil, ErrTokenInvalid
} else {
return nil, errors.New("无效的令牌,数据库查询出错,请联系管理员")
} }
return nil, fmt.Errorf("%w: %v", ErrDatabase, err)
} }
func GetTokenByIds(id int, userId int) (*Token, error) { func GetTokenByIds(id int, userId int) (*Token, error) {
@@ -489,3 +480,32 @@ func GetTokenKeysByIds(ids []int, userId int) ([]Token, error) {
Find(&tokens).Error Find(&tokens).Error
return tokens, err return tokens, err
} }
// InvalidateUserTokensCache 清理指定用户所有令牌在 Redis 中的缓存,
// 配合 InvalidateUserCache 使用,可在用户被禁用/删除时立即阻断其令牌的请求。
// 下一次请求将从数据库重新加载令牌及用户状态,从而立即识别出被禁用的用户。
func InvalidateUserTokensCache(userId int) error {
if !common.RedisEnabled {
return nil
}
if userId <= 0 {
return errors.New("userId 无效")
}
var tokens []Token
if err := DB.Unscoped().
Select("id", commonKeyCol).
Where("user_id = ?", userId).
Find(&tokens).Error; err != nil {
return err
}
var firstErr error
for _, t := range tokens {
if t.Key == "" {
continue
}
if err := cacheDeleteToken(t.Key); err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
}
+174 -33
View File
@@ -12,17 +12,30 @@ import (
) )
type TopUp struct { type TopUp struct {
Id int `json:"id"` Id int `json:"id"`
UserId int `json:"user_id" gorm:"index"` UserId int `json:"user_id" gorm:"index"`
Amount int64 `json:"amount"` Amount int64 `json:"amount"`
Money float64 `json:"money"` Money float64 `json:"money"`
TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"` TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"` PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"`
CreateTime int64 `json:"create_time"` CreateTime int64 `json:"create_time"`
CompleteTime int64 `json:"complete_time"` CompleteTime int64 `json:"complete_time"`
Status string `json:"status"` Status string `json:"status"`
} }
const (
PaymentMethodStripe = "stripe"
PaymentMethodCreem = "creem"
PaymentMethodWaffo = "waffo"
PaymentMethodWaffoPancake = "waffo_pancake"
)
var (
ErrPaymentMethodMismatch = errors.New("payment method mismatch")
ErrTopUpNotFound = errors.New("topup not found")
ErrTopUpStatusInvalid = errors.New("topup status invalid")
)
func (topUp *TopUp) Insert() error { func (topUp *TopUp) Insert() error {
var err error var err error
err = DB.Create(topUp).Error err = DB.Create(topUp).Error
@@ -55,7 +68,34 @@ func GetTopUpByTradeNo(tradeNo string) *TopUp {
return topUp return topUp
} }
func Recharge(referenceId string, customerId string) (err error) { func UpdatePendingTopUpStatus(tradeNo string, expectedPaymentMethod string, targetStatus string) error {
if tradeNo == "" {
return errors.New("未提供支付单号")
}
refCol := "`trade_no`"
if common.UsingPostgreSQL {
refCol = `"trade_no"`
}
return DB.Transaction(func(tx *gorm.DB) error {
topUp := &TopUp{}
if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(topUp).Error; err != nil {
return ErrTopUpNotFound
}
if expectedPaymentMethod != "" && topUp.PaymentMethod != expectedPaymentMethod {
return ErrPaymentMethodMismatch
}
if topUp.Status != common.TopUpStatusPending {
return ErrTopUpStatusInvalid
}
topUp.Status = targetStatus
return tx.Save(topUp).Error
})
}
func Recharge(referenceId string, customerId string, callerIp string) (err error) {
if referenceId == "" { if referenceId == "" {
return errors.New("未提供支付单号") return errors.New("未提供支付单号")
} }
@@ -74,6 +114,10 @@ func Recharge(referenceId string, customerId string) (err error) {
return errors.New("充值订单不存在") return errors.New("充值订单不存在")
} }
if topUp.PaymentMethod != PaymentMethodStripe {
return ErrPaymentMethodMismatch
}
if topUp.Status != common.TopUpStatusPending { if topUp.Status != common.TopUpStatusPending {
return errors.New("充值订单状态错误") return errors.New("充值订单状态错误")
} }
@@ -99,11 +143,19 @@ func Recharge(referenceId string, customerId string) (err error) {
return errors.New("充值失败,请稍后重试") return errors.New("充值失败,请稍后重试")
} }
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", logger.FormatQuota(int(quota)), topUp.Amount)) RecordTopupLog(topUp.UserId, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", logger.FormatQuota(int(quota)), topUp.Amount), callerIp, topUp.PaymentMethod, PaymentMethodStripe)
return nil return nil
} }
// topUpQueryWindowSeconds 限制充值记录查询的时间窗口(秒)。
const topUpQueryWindowSeconds int64 = 30 * 24 * 60 * 60
// topUpQueryCutoff 返回允许查询的最早 create_time(秒级 Unix 时间戳)。
func topUpQueryCutoff() int64 {
return common.GetTimestamp() - topUpQueryWindowSeconds
}
func GetUserTopUps(userId int, pageInfo *common.PageInfo) (topups []*TopUp, total int64, err error) { func GetUserTopUps(userId int, pageInfo *common.PageInfo) (topups []*TopUp, total int64, err error) {
// Start transaction // Start transaction
tx := DB.Begin() tx := DB.Begin()
@@ -116,15 +168,17 @@ func GetUserTopUps(userId int, pageInfo *common.PageInfo) (topups []*TopUp, tota
} }
}() }()
cutoff := topUpQueryCutoff()
// Get total count within transaction // Get total count within transaction
err = tx.Model(&TopUp{}).Where("user_id = ?", userId).Count(&total).Error err = tx.Model(&TopUp{}).Where("user_id = ? AND create_time >= ?", userId, cutoff).Count(&total).Error
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
return nil, 0, err return nil, 0, err
} }
// Get paginated topups within same transaction // Get paginated topups within same transaction
err = tx.Where("user_id = ?", userId).Order("id desc").Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Find(&topups).Error err = tx.Where("user_id = ? AND create_time >= ?", userId, cutoff).Order("id desc").Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Find(&topups).Error
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
return nil, 0, err return nil, 0, err
@@ -138,7 +192,7 @@ func GetUserTopUps(userId int, pageInfo *common.PageInfo) (topups []*TopUp, tota
return topups, total, nil return topups, total, nil
} }
// GetAllTopUps 获取全平台的充值记录(管理员使用) // GetAllTopUps 获取全平台的充值记录(管理员使用,不限制时间窗口
func GetAllTopUps(pageInfo *common.PageInfo) (topups []*TopUp, total int64, err error) { func GetAllTopUps(pageInfo *common.PageInfo) (topups []*TopUp, total int64, err error) {
tx := DB.Begin() tx := DB.Begin()
if tx.Error != nil { if tx.Error != nil {
@@ -167,6 +221,10 @@ func GetAllTopUps(pageInfo *common.PageInfo) (topups []*TopUp, total int64, err
return topups, total, nil return topups, total, nil
} }
// searchTopUpCountHardLimit 搜索充值记录时 COUNT 的安全上限,
// 防止对超大表执行无界 COUNT 触发 DoS。
const searchTopUpCountHardLimit = 10000
// SearchUserTopUps 按订单号搜索某用户的充值记录 // SearchUserTopUps 按订单号搜索某用户的充值记录
func SearchUserTopUps(userId int, keyword string, pageInfo *common.PageInfo) (topups []*TopUp, total int64, err error) { func SearchUserTopUps(userId int, keyword string, pageInfo *common.PageInfo) (topups []*TopUp, total int64, err error) {
tx := DB.Begin() tx := DB.Begin()
@@ -179,20 +237,26 @@ func SearchUserTopUps(userId int, keyword string, pageInfo *common.PageInfo) (to
} }
}() }()
query := tx.Model(&TopUp{}).Where("user_id = ?", userId) query := tx.Model(&TopUp{}).Where("user_id = ? AND create_time >= ?", userId, topUpQueryCutoff())
if keyword != "" { if keyword != "" {
like := "%%" + keyword + "%%" pattern, perr := sanitizeLikePattern(keyword)
query = query.Where("trade_no LIKE ?", like) if perr != nil {
tx.Rollback()
return nil, 0, perr
}
query = query.Where("trade_no LIKE ? ESCAPE '!'", pattern)
} }
if err = query.Count(&total).Error; err != nil { if err = query.Limit(searchTopUpCountHardLimit).Count(&total).Error; err != nil {
tx.Rollback() tx.Rollback()
return nil, 0, err common.SysError("failed to count search topups: " + err.Error())
return nil, 0, errors.New("搜索充值记录失败")
} }
if err = query.Order("id desc").Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Find(&topups).Error; err != nil { if err = query.Order("id desc").Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Find(&topups).Error; err != nil {
tx.Rollback() tx.Rollback()
return nil, 0, err common.SysError("failed to search topups: " + err.Error())
return nil, 0, errors.New("搜索充值记录失败")
} }
if err = tx.Commit().Error; err != nil { if err = tx.Commit().Error; err != nil {
@@ -201,7 +265,7 @@ func SearchUserTopUps(userId int, keyword string, pageInfo *common.PageInfo) (to
return topups, total, nil return topups, total, nil
} }
// SearchAllTopUps 按订单号搜索全平台充值记录(管理员使用) // SearchAllTopUps 按订单号搜索全平台充值记录(管理员使用,不限制时间窗口
func SearchAllTopUps(keyword string, pageInfo *common.PageInfo) (topups []*TopUp, total int64, err error) { func SearchAllTopUps(keyword string, pageInfo *common.PageInfo) (topups []*TopUp, total int64, err error) {
tx := DB.Begin() tx := DB.Begin()
if tx.Error != nil { if tx.Error != nil {
@@ -215,18 +279,24 @@ func SearchAllTopUps(keyword string, pageInfo *common.PageInfo) (topups []*TopUp
query := tx.Model(&TopUp{}) query := tx.Model(&TopUp{})
if keyword != "" { if keyword != "" {
like := "%%" + keyword + "%%" pattern, perr := sanitizeLikePattern(keyword)
query = query.Where("trade_no LIKE ?", like) if perr != nil {
tx.Rollback()
return nil, 0, perr
}
query = query.Where("trade_no LIKE ? ESCAPE '!'", pattern)
} }
if err = query.Count(&total).Error; err != nil { if err = query.Limit(searchTopUpCountHardLimit).Count(&total).Error; err != nil {
tx.Rollback() tx.Rollback()
return nil, 0, err common.SysError("failed to count search topups: " + err.Error())
return nil, 0, errors.New("搜索充值记录失败")
} }
if err = query.Order("id desc").Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Find(&topups).Error; err != nil { if err = query.Order("id desc").Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Find(&topups).Error; err != nil {
tx.Rollback() tx.Rollback()
return nil, 0, err common.SysError("failed to search topups: " + err.Error())
return nil, 0, errors.New("搜索充值记录失败")
} }
if err = tx.Commit().Error; err != nil { if err = tx.Commit().Error; err != nil {
@@ -236,7 +306,7 @@ func SearchAllTopUps(keyword string, pageInfo *common.PageInfo) (topups []*TopUp
} }
// ManualCompleteTopUp 管理员手动完成订单并给用户充值 // ManualCompleteTopUp 管理员手动完成订单并给用户充值
func ManualCompleteTopUp(tradeNo string) error { func ManualCompleteTopUp(tradeNo string, callerIp string) error {
if tradeNo == "" { if tradeNo == "" {
return errors.New("未提供订单号") return errors.New("未提供订单号")
} }
@@ -249,6 +319,7 @@ func ManualCompleteTopUp(tradeNo string) error {
var userId int var userId int
var quotaToAdd int var quotaToAdd int
var payMoney float64 var payMoney float64
var paymentMethod string
err := DB.Transaction(func(tx *gorm.DB) error { err := DB.Transaction(func(tx *gorm.DB) error {
topUp := &TopUp{} topUp := &TopUp{}
@@ -269,7 +340,7 @@ func ManualCompleteTopUp(tradeNo string) error {
// 计算应充值额度: // 计算应充值额度:
// - Stripe 订单:Money 代表经分组倍率换算后的美元数量,直接 * QuotaPerUnit // - Stripe 订单:Money 代表经分组倍率换算后的美元数量,直接 * QuotaPerUnit
// - 其他订单(如易支付):Amount 为美元数量,* QuotaPerUnit // - 其他订单(如易支付):Amount 为美元数量,* QuotaPerUnit
if topUp.PaymentMethod == "stripe" { if topUp.PaymentMethod == PaymentMethodStripe {
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit) dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
quotaToAdd = int(decimal.NewFromFloat(topUp.Money).Mul(dQuotaPerUnit).IntPart()) quotaToAdd = int(decimal.NewFromFloat(topUp.Money).Mul(dQuotaPerUnit).IntPart())
} else { } else {
@@ -295,6 +366,7 @@ func ManualCompleteTopUp(tradeNo string) error {
userId = topUp.UserId userId = topUp.UserId
payMoney = topUp.Money payMoney = topUp.Money
paymentMethod = topUp.PaymentMethod
return nil return nil
}) })
@@ -303,10 +375,10 @@ func ManualCompleteTopUp(tradeNo string) error {
} }
// 事务外记录日志,避免阻塞 // 事务外记录日志,避免阻塞
RecordLog(userId, LogTypeTopup, fmt.Sprintf("管理员补单成功,充值金额: %v,支付金额:%f", logger.FormatQuota(quotaToAdd), payMoney)) RecordTopupLog(userId, fmt.Sprintf("管理员补单成功,充值金额: %v,支付金额:%f", logger.FormatQuota(quotaToAdd), payMoney), callerIp, paymentMethod, "admin")
return nil return nil
} }
func RechargeCreem(referenceId string, customerEmail string, customerName string) (err error) { func RechargeCreem(referenceId string, customerEmail string, customerName string, callerIp string) (err error) {
if referenceId == "" { if referenceId == "" {
return errors.New("未提供支付单号") return errors.New("未提供支付单号")
} }
@@ -325,6 +397,10 @@ func RechargeCreem(referenceId string, customerEmail string, customerName string
return errors.New("充值订单不存在") return errors.New("充值订单不存在")
} }
if topUp.PaymentMethod != PaymentMethodCreem {
return ErrPaymentMethodMismatch
}
if topUp.Status != common.TopUpStatusPending { if topUp.Status != common.TopUpStatusPending {
return errors.New("充值订单状态错误") return errors.New("充值订单状态错误")
} }
@@ -372,12 +448,12 @@ func RechargeCreem(referenceId string, customerEmail string, customerName string
return errors.New("充值失败,请稍后重试") return errors.New("充值失败,请稍后重试")
} }
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用Creem充值成功,充值额度: %v,支付金额:%.2f", quota, topUp.Money)) RecordTopupLog(topUp.UserId, fmt.Sprintf("使用Creem充值成功,充值额度: %v,支付金额:%.2f", quota, topUp.Money), callerIp, topUp.PaymentMethod, PaymentMethodCreem)
return nil return nil
} }
func RechargeWaffo(tradeNo string) (err error) { func RechargeWaffo(tradeNo string, callerIp string) (err error) {
if tradeNo == "" { if tradeNo == "" {
return errors.New("未提供支付单号") return errors.New("未提供支付单号")
} }
@@ -396,6 +472,10 @@ func RechargeWaffo(tradeNo string) (err error) {
return errors.New("充值订单不存在") return errors.New("充值订单不存在")
} }
if topUp.PaymentMethod != PaymentMethodWaffo {
return ErrPaymentMethodMismatch
}
if topUp.Status == common.TopUpStatusSuccess { if topUp.Status == common.TopUpStatusSuccess {
return nil // 幂等:已成功直接返回 return nil // 幂等:已成功直接返回
} }
@@ -430,7 +510,68 @@ func RechargeWaffo(tradeNo string) (err error) {
} }
if quotaToAdd > 0 { if quotaToAdd > 0 {
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("Waffo充值成功,充值额度: %v,支付金额: %.2f", logger.FormatQuota(quotaToAdd), topUp.Money)) RecordTopupLog(topUp.UserId, fmt.Sprintf("Waffo充值成功,充值额度: %v,支付金额: %.2f", logger.FormatQuota(quotaToAdd), topUp.Money), callerIp, topUp.PaymentMethod, PaymentMethodWaffo)
}
return nil
}
func RechargeWaffoPancake(tradeNo string) (err error) {
if tradeNo == "" {
return errors.New("未提供支付单号")
}
var quotaToAdd int
topUp := &TopUp{}
refCol := "`trade_no`"
if common.UsingPostgreSQL {
refCol = `"trade_no"`
}
err = DB.Transaction(func(tx *gorm.DB) error {
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(topUp).Error
if err != nil {
return errors.New("充值订单不存在")
}
if topUp.PaymentMethod != PaymentMethodWaffoPancake {
return ErrPaymentMethodMismatch
}
if topUp.Status == common.TopUpStatusSuccess {
return nil
}
if topUp.Status != common.TopUpStatusPending {
return errors.New("充值订单状态错误")
}
quotaToAdd = int(decimal.NewFromInt(topUp.Amount).Mul(decimal.NewFromFloat(common.QuotaPerUnit)).IntPart())
if quotaToAdd <= 0 {
return errors.New("无效的充值额度")
}
topUp.CompleteTime = common.GetTimestamp()
topUp.Status = common.TopUpStatusSuccess
if err := tx.Save(topUp).Error; err != nil {
return err
}
if err := tx.Model(&User{}).Where("id = ?", topUp.UserId).Update("quota", gorm.Expr("quota + ?", quotaToAdd)).Error; err != nil {
return err
}
return nil
})
if err != nil {
common.SysError("waffo pancake topup failed: " + err.Error())
return errors.New("充值失败,请稍后重试")
}
if quotaToAdd > 0 {
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("Waffo Pancake充值成功,充值额度: %v,支付金额: %.2f", logger.FormatQuota(quotaToAdd), topUp.Money))
} }
return nil return nil
-2
View File
@@ -10,8 +10,6 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
var ErrTwoFANotEnabled = errors.New("用户未启用2FA")
// TwoFA 用户2FA设置表 // TwoFA 用户2FA设置表
type TwoFA struct { type TwoFA struct {
Id int `json:"id" gorm:"primaryKey"` Id int `json:"id" gorm:"primaryKey"`
+23 -14
View File
@@ -523,7 +523,6 @@ func (user *User) Edit(updatePassword bool) error {
"username": newUser.Username, "username": newUser.Username,
"display_name": newUser.DisplayName, "display_name": newUser.DisplayName,
"group": newUser.Group, "group": newUser.Group,
"quota": newUser.Quota,
"remark": newUser.Remark, "remark": newUser.Remark,
} }
if updatePassword { if updatePassword {
@@ -598,13 +597,19 @@ func (user *User) ValidateAndFill() (err error) {
password := user.Password password := user.Password
username := strings.TrimSpace(user.Username) username := strings.TrimSpace(user.Username)
if username == "" || password == "" { if username == "" || password == "" {
return errors.New("用户名或密码为空") return ErrUserEmptyCredentials
}
// find by username or email
err = DB.Where("username = ? OR email = ?", username, username).First(user).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrInvalidCredentials
}
return fmt.Errorf("%w: %v", ErrDatabase, err)
} }
// find buy username or email
DB.Where("username = ? OR email = ?", username, username).First(user)
okay := common.ValidatePasswordAndHash(password, user.Password) okay := common.ValidatePasswordAndHash(password, user.Password)
if !okay || user.Status != common.UserStatusEnabled { if !okay || user.Status != common.UserStatusEnabled {
return errors.New("用户名或密码错误,或用户已被封禁") return ErrInvalidCredentials
} }
return nil return nil
} }
@@ -755,16 +760,20 @@ func IsAdmin(userId int) bool {
// return user.Status == common.UserStatusEnabled, nil // return user.Status == common.UserStatusEnabled, nil
//} //}
func ValidateAccessToken(token string) (user *User) { func ValidateAccessToken(token string) (*User, error) {
if token == "" { if token == "" {
return nil return nil, nil
} }
token = strings.Replace(token, "Bearer ", "", 1) token = strings.Replace(token, "Bearer ", "", 1)
user = &User{} user := &User{}
if DB.Where("access_token = ?", token).First(user).RowsAffected == 1 { err := DB.Where("access_token = ?", token).First(user).Error
return user if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, fmt.Errorf("%w: %v", ErrDatabase, err)
} }
return nil return user, nil
} }
// GetUserQuota gets quota from Redis first, falls back to DB if needed // GetUserQuota gets quota from Redis first, falls back to DB if needed
@@ -896,7 +905,7 @@ func increaseUserQuota(id int, quota int) (err error) {
return err return err
} }
func DecreaseUserQuota(id int, quota int) (err error) { func DecreaseUserQuota(id int, quota int, db bool) (err error) {
if quota < 0 { if quota < 0 {
return errors.New("quota 不能为负数!") return errors.New("quota 不能为负数!")
} }
@@ -906,7 +915,7 @@ func DecreaseUserQuota(id int, quota int) (err error) {
common.SysLog("failed to decrease user quota: " + err.Error()) common.SysLog("failed to decrease user quota: " + err.Error())
} }
}) })
if common.BatchUpdateEnabled { if !db && common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, -quota) addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
return nil return nil
} }
@@ -928,7 +937,7 @@ func DeltaUpdateUserQuota(id int, delta int) (err error) {
if delta > 0 { if delta > 0 {
return IncreaseUserQuota(id, delta, false) return IncreaseUserQuota(id, delta, false)
} else { } else {
return DecreaseUserQuota(id, -delta) return DecreaseUserQuota(id, -delta, false)
} }
} }
+6
View File
@@ -57,6 +57,12 @@ func invalidateUserCache(userId int) error {
return common.RedisDelKey(getUserCacheKey(userId)) return common.RedisDelKey(getUserCacheKey(userId))
} }
// InvalidateUserCache is the exported version of invalidateUserCache.
// 供 controller 等上层包在用户状态变更(如禁用、删除、角色变更)后主动清理缓存。
func InvalidateUserCache(userId int) error {
return invalidateUserCache(userId)
}
// updateUserCache updates all user cache fields using hash // updateUserCache updates all user cache fields using hash
func updateUserCache(user User) error { func updateUserCache(user User) error {
if !common.RedisEnabled { if !common.RedisEnabled {
+6
View File
@@ -18,6 +18,7 @@ var awsModelIDMap = map[string]string{
"claude-haiku-4-5-20251001": "anthropic.claude-haiku-4-5-20251001-v1:0", "claude-haiku-4-5-20251001": "anthropic.claude-haiku-4-5-20251001-v1:0",
"claude-opus-4-5-20251101": "anthropic.claude-opus-4-5-20251101-v1:0", "claude-opus-4-5-20251101": "anthropic.claude-opus-4-5-20251101-v1:0",
"claude-opus-4-6": "anthropic.claude-opus-4-6-v1", "claude-opus-4-6": "anthropic.claude-opus-4-6-v1",
"claude-opus-4-7": "anthropic.claude-opus-4-7",
// Nova models // Nova models
"nova-micro-v1:0": "amazon.nova-micro-v1:0", "nova-micro-v1:0": "amazon.nova-micro-v1:0",
"nova-lite-v1:0": "amazon.nova-lite-v1:0", "nova-lite-v1:0": "amazon.nova-lite-v1:0",
@@ -91,6 +92,11 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
"ap": true, "ap": true,
"eu": true, "eu": true,
}, },
"anthropic.claude-opus-4-7": {
"us": true,
"ap": true,
"eu": true,
},
"anthropic.claude-haiku-4-5-20251001-v1:0": { "anthropic.claude-haiku-4-5-20251001-v1:0": {
"us": true, "us": true,
"ap": true, "ap": true,
+7
View File
@@ -26,6 +26,13 @@ var ModelList = []string{
"claude-opus-4-6-medium", "claude-opus-4-6-medium",
"claude-opus-4-6-low", "claude-opus-4-6-low",
"claude-sonnet-4-6", "claude-sonnet-4-6",
"claude-opus-4-7",
"claude-opus-4-7-max",
"claude-opus-4-7-xhigh",
"claude-opus-4-7-high",
"claude-opus-4-7-medium",
"claude-opus-4-7-low",
"claude-opus-4-7-thinking",
} }
var ChannelName = "claude" var ChannelName = "claude"
+54 -27
View File
@@ -154,33 +154,52 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
} }
if baseModel, effortLevel, ok := reasoning.TrimEffortSuffix(textRequest.Model); ok && effortLevel != "" && if baseModel, effortLevel, ok := reasoning.TrimEffortSuffix(textRequest.Model); ok && effortLevel != "" &&
strings.HasPrefix(textRequest.Model, "claude-opus-4-6") { (strings.HasPrefix(textRequest.Model, "claude-opus-4-6") || strings.HasPrefix(textRequest.Model, "claude-opus-4-7")) {
claudeRequest.Model = baseModel claudeRequest.Model = baseModel
claudeRequest.Thinking = &dto.Thinking{ claudeRequest.Thinking = &dto.Thinking{
Type: "adaptive", Type: "adaptive",
} }
claudeRequest.OutputConfig = json.RawMessage(fmt.Sprintf(`{"effort":"%s"}`, effortLevel)) claudeRequest.OutputConfig = json.RawMessage(fmt.Sprintf(`{"effort":"%s"}`, effortLevel))
claudeRequest.TopP = common.GetPointer[float64](0) if strings.HasPrefix(baseModel, "claude-opus-4-7") {
claudeRequest.Temperature = common.GetPointer[float64](1.0) // Opus 4.7 rejects non-default temperature/top_p/top_k with 400
// and defaults display to "omitted"; restore the 4.6 visible summary.
claudeRequest.Thinking.Display = "summarized"
claudeRequest.Temperature = nil
claudeRequest.TopP = nil
claudeRequest.TopK = nil
} else {
claudeRequest.TopP = nil
claudeRequest.Temperature = common.GetPointer[float64](1.0)
}
} else if model_setting.GetClaudeSettings().ThinkingAdapterEnabled && } else if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
strings.HasSuffix(textRequest.Model, "-thinking") { strings.HasSuffix(textRequest.Model, "-thinking") {
// 因为BudgetTokens 必须大于1024 trimmedModel := strings.TrimSuffix(textRequest.Model, "-thinking")
if claudeRequest.MaxTokens == nil || *claudeRequest.MaxTokens < 1280 { if strings.HasPrefix(trimmedModel, "claude-opus-4-7") {
claudeRequest.MaxTokens = common.GetPointer[uint](1280) // Opus 4.7 rejects thinking.type="enabled"; use adaptive at high effort.
} claudeRequest.Thinking = &dto.Thinking{Type: "adaptive", Display: "summarized"}
claudeRequest.OutputConfig = json.RawMessage(`{"effort":"high"}`)
claudeRequest.Temperature = nil
claudeRequest.TopP = nil
claudeRequest.TopK = nil
} else {
// 因为BudgetTokens 必须大于1024
if claudeRequest.MaxTokens == nil || *claudeRequest.MaxTokens < 1280 {
claudeRequest.MaxTokens = common.GetPointer[uint](1280)
}
// BudgetTokens 为 max_tokens 的 80% // BudgetTokens 为 max_tokens 的 80%
claudeRequest.Thinking = &dto.Thinking{ claudeRequest.Thinking = &dto.Thinking{
Type: "enabled", Type: "enabled",
BudgetTokens: common.GetPointer[int](int(float64(*claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)), BudgetTokens: common.GetPointer[int](int(float64(*claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
}
// TODO: 临时处理
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
claudeRequest.TopP = nil
claudeRequest.Temperature = common.GetPointer[float64](1.0)
} }
// TODO: 临时处理
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
claudeRequest.TopP = nil
claudeRequest.Temperature = common.GetPointer[float64](1.0)
if !model_setting.ShouldPreserveThinkingSuffix(textRequest.Model) { if !model_setting.ShouldPreserveThinkingSuffix(textRequest.Model) {
claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking") claudeRequest.Model = trimmedModel
} }
} }
@@ -258,7 +277,7 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
formatMessages = formatMessages[:len(formatMessages)-1] formatMessages = formatMessages[:len(formatMessages)-1]
} }
} }
if fmtMessage.Content == nil { if fmtMessage.Content == nil || (fmtMessage.IsStringContent() && fmtMessage.StringContent() == "") {
fmtMessage.SetStringContent("...") fmtMessage.SetStringContent("...")
} }
formatMessages = append(formatMessages, fmtMessage) formatMessages = append(formatMessages, fmtMessage)
@@ -274,14 +293,16 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
if message.Role == "system" { if message.Role == "system" {
// 根据Claude API规范,system字段使用数组格式更有通用性 // 根据Claude API规范,system字段使用数组格式更有通用性
if message.IsStringContent() { if message.IsStringContent() {
systemMessages = append(systemMessages, dto.ClaudeMediaMessage{ if text := message.StringContent(); text != "" {
Type: "text", systemMessages = append(systemMessages, dto.ClaudeMediaMessage{
Text: common.GetPointer[string](message.StringContent()), Type: "text",
}) Text: common.GetPointer[string](text),
})
}
} else { } else {
// 支持复合内容的system消息(虽然不常见,但需要考虑完整性) // 支持复合内容的system消息(虽然不常见,但需要考虑完整性)
for _, ctx := range message.ParseContent() { for _, ctx := range message.ParseContent() {
if ctx.Type == "text" { if ctx.Type == "text" && ctx.Text != "" {
systemMessages = append(systemMessages, dto.ClaudeMediaMessage{ systemMessages = append(systemMessages, dto.ClaudeMediaMessage{
Type: "text", Type: "text",
Text: common.GetPointer[string](ctx.Text), Text: common.GetPointer[string](ctx.Text),
@@ -339,16 +360,22 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
} }
} }
} else if message.IsStringContent() && message.ToolCalls == nil { } else if message.IsStringContent() && message.ToolCalls == nil {
claudeMessage.Content = message.StringContent() text := message.StringContent()
if text == "" {
text = "..."
}
claudeMessage.Content = text
} else { } else {
claudeMediaMessages := make([]dto.ClaudeMediaMessage, 0) claudeMediaMessages := make([]dto.ClaudeMediaMessage, 0)
for _, mediaMessage := range message.ParseContent() { for _, mediaMessage := range message.ParseContent() {
switch mediaMessage.Type { switch mediaMessage.Type {
case "text": case "text":
claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{ if mediaMessage.Text != "" {
Type: "text", claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{
Text: common.GetPointer[string](mediaMessage.Text), Type: "text",
}) Text: common.GetPointer[string](mediaMessage.Text),
})
}
default: default:
source := mediaMessage.ToFileSource() source := mediaMessage.ToFileSource()
if source == nil { if source == nil {
+7 -2
View File
@@ -136,8 +136,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
task = "chat/completions" + task task = "chat/completions" + task
} }
// 特殊处理 responses API // 特殊处理 responses API(包含 compact
if info.RelayMode == relayconstant.RelayModeResponses { if info.RelayMode == relayconstant.RelayModeResponses || info.RelayMode == relayconstant.RelayModeResponsesCompact {
responsesApiVersion := "preview" responsesApiVersion := "preview"
subUrl := "/openai/v1/responses" subUrl := "/openai/v1/responses"
@@ -150,6 +150,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
responsesApiVersion = info.ChannelOtherSettings.AzureResponsesVersion responsesApiVersion = info.ChannelOtherSettings.AzureResponsesVersion
} }
// compact 模式追加 /compact
if info.RelayMode == relayconstant.RelayModeResponsesCompact {
subUrl = subUrl + "/compact"
}
requestURL = fmt.Sprintf("%s?api-version=%s", subUrl, responsesApiVersion) requestURL = fmt.Sprintf("%s?api-version=%s", subUrl, responsesApiVersion)
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestURL, info.ChannelType), nil return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestURL, info.ChannelType), nil
} }
+1
View File
@@ -44,6 +44,7 @@ var claudeModelMap = map[string]string{
"claude-haiku-4-5-20251001": "claude-haiku-4-5@20251001", "claude-haiku-4-5-20251001": "claude-haiku-4-5@20251001",
"claude-opus-4-5-20251101": "claude-opus-4-5@20251101", "claude-opus-4-5-20251101": "claude-opus-4-5@20251101",
"claude-opus-4-6": "claude-opus-4-6", "claude-opus-4-6": "claude-opus-4-6",
"claude-opus-4-7": "claude-opus-4-7",
} }
const anthropicVersion = "vertex-2023-10-16" const anthropicVersion = "vertex-2023-10-16"
+32 -13
View File
@@ -53,30 +53,49 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
} }
if baseModel, effortLevel, ok := reasoning.TrimEffortSuffix(request.Model); ok && effortLevel != "" && if baseModel, effortLevel, ok := reasoning.TrimEffortSuffix(request.Model); ok && effortLevel != "" &&
strings.HasPrefix(request.Model, "claude-opus-4-6") { (strings.HasPrefix(request.Model, "claude-opus-4-6") || strings.HasPrefix(request.Model, "claude-opus-4-7")) {
request.Model = baseModel request.Model = baseModel
request.Thinking = &dto.Thinking{ request.Thinking = &dto.Thinking{
Type: "adaptive", Type: "adaptive",
} }
request.OutputConfig = json.RawMessage(fmt.Sprintf(`{"effort":"%s"}`, effortLevel)) request.OutputConfig = json.RawMessage(fmt.Sprintf(`{"effort":"%s"}`, effortLevel))
request.Temperature = common.GetPointer[float64](1.0) if strings.HasPrefix(request.Model, "claude-opus-4-7") {
// Opus 4.7 rejects non-default temperature/top_p/top_k with 400
// and defaults display to "omitted"; restore the 4.6 visible summary.
request.Thinking.Display = "summarized"
request.Temperature = nil
request.TopP = nil
request.TopK = nil
} else {
request.Temperature = common.GetPointer[float64](1.0)
}
info.UpstreamModelName = request.Model info.UpstreamModelName = request.Model
} else if model_setting.GetClaudeSettings().ThinkingAdapterEnabled && } else if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
strings.HasSuffix(request.Model, "-thinking") { strings.HasSuffix(request.Model, "-thinking") {
if request.Thinking == nil { if request.Thinking == nil {
// 因为BudgetTokens 必须大于1024 baseModel := strings.TrimSuffix(request.Model, "-thinking")
if request.MaxTokens == nil || *request.MaxTokens < 1280 { if strings.HasPrefix(baseModel, "claude-opus-4-7") {
request.MaxTokens = common.GetPointer[uint](1280) // Opus 4.7 rejects thinking.type="enabled"; use adaptive at high effort.
} request.Thinking = &dto.Thinking{Type: "adaptive", Display: "summarized"}
request.OutputConfig = json.RawMessage(`{"effort":"high"}`)
request.Temperature = nil
request.TopP = nil
request.TopK = nil
} else {
// 因为BudgetTokens 必须大于1024
if request.MaxTokens == nil || *request.MaxTokens < 1280 {
request.MaxTokens = common.GetPointer[uint](1280)
}
// BudgetTokens 为 max_tokens 的 80% // BudgetTokens 为 max_tokens 的 80%
request.Thinking = &dto.Thinking{ request.Thinking = &dto.Thinking{
Type: "enabled", Type: "enabled",
BudgetTokens: common.GetPointer[int](int(float64(*request.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)), BudgetTokens: common.GetPointer[int](int(float64(*request.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
}
// TODO: 临时处理
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
request.Temperature = common.GetPointer[float64](1.0)
} }
// TODO: 临时处理
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
request.Temperature = common.GetPointer[float64](1.0)
} }
if !model_setting.ShouldPreserveThinkingSuffix(info.OriginModelName) { if !model_setting.ShouldPreserveThinkingSuffix(info.OriginModelName) {
request.Model = strings.TrimSuffix(request.Model, "-thinking") request.Model = strings.TrimSuffix(request.Model, "-thinking")
+1
View File
@@ -32,6 +32,7 @@ var paramOverrideKeyAuditPaths = map[string]struct{}{
"upstream_model": {}, "upstream_model": {},
"service_tier": {}, "service_tier": {},
"inference_geo": {}, "inference_geo": {},
"speed": {},
} }
type paramOverrideAuditRecorder struct { type paramOverrideAuditRecorder struct {
+19 -1
View File
@@ -2038,6 +2038,8 @@ func TestRemoveDisabledFieldsDefaultFiltering(t *testing.T) {
input := `{ input := `{
"service_tier":"flex", "service_tier":"flex",
"inference_geo":"eu", "inference_geo":"eu",
"speed":"fast",
"cache_control":{"type":"ephemeral"},
"safety_identifier":"user-123", "safety_identifier":"user-123",
"store":true, "store":true,
"stream_options":{"include_obfuscation":false} "stream_options":{"include_obfuscation":false}
@@ -2048,7 +2050,7 @@ func TestRemoveDisabledFieldsDefaultFiltering(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("RemoveDisabledFields returned error: %v", err) t.Fatalf("RemoveDisabledFields returned error: %v", err)
} }
assertJSONEqual(t, `{"store":true}`, string(out)) assertJSONEqual(t, `{"cache_control":{"type":"ephemeral"},"store":true}`, string(out))
} }
func TestRemoveDisabledFieldsAllowInferenceGeo(t *testing.T) { func TestRemoveDisabledFieldsAllowInferenceGeo(t *testing.T) {
@@ -2067,6 +2069,22 @@ func TestRemoveDisabledFieldsAllowInferenceGeo(t *testing.T) {
assertJSONEqual(t, `{"inference_geo":"eu","store":true}`, string(out)) assertJSONEqual(t, `{"inference_geo":"eu","store":true}`, string(out))
} }
func TestRemoveDisabledFieldsAllowSpeed(t *testing.T) {
input := `{
"speed":"fast",
"store":true
}`
settings := dto.ChannelOtherSettings{
AllowSpeed: true,
}
out, err := RemoveDisabledFields([]byte(input), settings, false)
if err != nil {
t.Fatalf("RemoveDisabledFields returned error: %v", err)
}
assertJSONEqual(t, `{"speed":"fast","store":true}`, string(out))
}
func TestApplyParamOverrideWithRelayInfoRecordsOperationAuditInDebugMode(t *testing.T) { func TestApplyParamOverrideWithRelayInfoRecordsOperationAuditInDebugMode(t *testing.T) {
originalDebugEnabled := common2.DebugEnabled originalDebugEnabled := common2.DebugEnabled
common2.DebugEnabled = true common2.DebugEnabled = true
+9
View File
@@ -444,6 +444,7 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
if request != nil { if request != nil {
isStream = request.IsStream(c) isStream = request.IsStream(c)
} }
c.Set(string(constant.ContextKeyIsStream), isStream)
// firstResponseTime = time.Now() - 1 second // firstResponseTime = time.Now() - 1 second
@@ -776,6 +777,7 @@ func FailTaskInfo(reason string) *TaskInfo {
// RemoveDisabledFields 从请求 JSON 数据中移除渠道设置中禁用的字段 // RemoveDisabledFields 从请求 JSON 数据中移除渠道设置中禁用的字段
// service_tier: 服务层级字段,可能导致额外计费(OpenAI、Claude、Responses API 支持) // service_tier: 服务层级字段,可能导致额外计费(OpenAI、Claude、Responses API 支持)
// inference_geo: Claude 数据驻留推理区域字段(仅 Claude 支持,默认过滤) // inference_geo: Claude 数据驻留推理区域字段(仅 Claude 支持,默认过滤)
// speed: Claude 推理速度模式字段(仅 Claude 支持,默认过滤)
// store: 数据存储授权字段,涉及用户隐私(仅 OpenAI、Responses API 支持,默认允许透传,禁用后可能导致 Codex 无法使用) // store: 数据存储授权字段,涉及用户隐私(仅 OpenAI、Responses API 支持,默认允许透传,禁用后可能导致 Codex 无法使用)
// safety_identifier: 安全标识符,用于向 OpenAI 报告违规用户(仅 OpenAI 支持,涉及用户隐私) // safety_identifier: 安全标识符,用于向 OpenAI 报告违规用户(仅 OpenAI 支持,涉及用户隐私)
// stream_options.include_obfuscation: 响应流混淆控制字段(仅 OpenAI Responses API 支持) // stream_options.include_obfuscation: 响应流混淆控制字段(仅 OpenAI Responses API 支持)
@@ -804,6 +806,13 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
} }
} }
// 默认移除 speed,除非明确允许(避免意外切换 Claude 推理速度模式)
if !channelOtherSettings.AllowSpeed {
if _, exists := data["speed"]; exists {
delete(data, "speed")
}
}
// 默认允许 store 透传,除非明确禁用(禁用可能影响 Codex 使用) // 默认允许 store 透传,除非明确禁用(禁用可能影响 Codex 使用)
if channelOtherSettings.DisableStore { if channelOtherSettings.DisableStore {
if _, exists := data["store"]; exists { if _, exists := data["store"]; exists {
+18 -2
View File
@@ -5,6 +5,7 @@ import (
"github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger" "github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/pkg/billingexpr" "github.com/QuantumNous/new-api/pkg/billingexpr"
relaycommon "github.com/QuantumNous/new-api/relay/common" relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/setting/billing_setting" "github.com/QuantumNous/new-api/setting/billing_setting"
@@ -15,6 +16,21 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func modelPriceNotConfiguredError(modelName string, userId int) error {
if model.IsAdmin(userId) {
return fmt.Errorf(
"模型 %s 的价格未配置。请前往「系统设置 → 运营设置」开启自用模式,或在「系统设置 → 分组与模型定价设置」中为该模型配置价格;"+
"Model %s price not configured. Go to System Settings → Operation Settings to enable self-use mode, or configure the model price in System Settings → Group & Model Pricing.",
modelName, modelName,
)
}
return fmt.Errorf(
"模型 %s 的价格尚未由管理员配置,暂时无法使用,请联系站点管理员开启该模型;"+
"Model %s has not been priced by the administrator yet. Please contact the site administrator to enable this model.",
modelName, modelName,
)
}
// https://docs.claude.com/en/docs/build-with-claude/prompt-caching#1-hour-cache-duration // https://docs.claude.com/en/docs/build-with-claude/prompt-caching#1-hour-cache-duration
const claudeCacheCreation1hMultiplier = 6 / 3.75 const claudeCacheCreation1hMultiplier = 6 / 3.75
@@ -82,7 +98,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
acceptUnsetRatio = true acceptUnsetRatio = true
} }
if !acceptUnsetRatio { if !acceptUnsetRatio {
return types.PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", matchName, matchName) return types.PriceData{}, modelPriceNotConfiguredError(matchName, info.UserId)
} }
} }
completionRatio = ratio_setting.GetCompletionRatio(info.OriginModelName) completionRatio = ratio_setting.GetCompletionRatio(info.OriginModelName)
@@ -168,7 +184,7 @@ func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) (types
acceptUnsetRatio = true acceptUnsetRatio = true
} }
if !ratioSuccess && !acceptUnsetRatio { if !ratioSuccess && !acceptUnsetRatio {
return types.PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", matchName, matchName) return types.PriceData{}, modelPriceNotConfiguredError(matchName, info.UserId)
} }
} }
} }
+1 -1
View File
@@ -143,7 +143,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
if err != nil { if err != nil {
info.OriginModelName = originModelName info.OriginModelName = originModelName
info.PriceData = originPriceData info.PriceData = originPriceData
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry(), types.ErrOptionWithStatusCode(http.StatusBadRequest))
} }
service.PostTextConsumeQuota(c, info, usageDto, nil) service.PostTextConsumeQuota(c, info, usageDto, nil)
+4
View File
@@ -49,6 +49,7 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.POST("/stripe/webhook", controller.StripeWebhook) apiRouter.POST("/stripe/webhook", controller.StripeWebhook)
apiRouter.POST("/creem/webhook", controller.CreemWebhook) apiRouter.POST("/creem/webhook", controller.CreemWebhook)
apiRouter.POST("/waffo/webhook", controller.WaffoWebhook) apiRouter.POST("/waffo/webhook", controller.WaffoWebhook)
//apiRouter.POST("/waffo-pancake/webhook", controller.WaffoPancakeWebhook)
// Universal secure verification routes // Universal secure verification routes
apiRouter.POST("/verify", middleware.UserAuth(), middleware.CriticalRateLimit(), controller.UniversalVerify) apiRouter.POST("/verify", middleware.UserAuth(), middleware.CriticalRateLimit(), controller.UniversalVerify)
@@ -90,7 +91,10 @@ func SetApiRouter(router *gin.Engine) {
selfRoute.POST("/stripe/pay", middleware.CriticalRateLimit(), controller.RequestStripePay) selfRoute.POST("/stripe/pay", middleware.CriticalRateLimit(), controller.RequestStripePay)
selfRoute.POST("/stripe/amount", controller.RequestStripeAmount) selfRoute.POST("/stripe/amount", controller.RequestStripeAmount)
selfRoute.POST("/creem/pay", middleware.CriticalRateLimit(), controller.RequestCreemPay) selfRoute.POST("/creem/pay", middleware.CriticalRateLimit(), controller.RequestCreemPay)
selfRoute.POST("/waffo/amount", controller.RequestWaffoAmount)
selfRoute.POST("/waffo/pay", middleware.CriticalRateLimit(), controller.RequestWaffoPay) selfRoute.POST("/waffo/pay", middleware.CriticalRateLimit(), controller.RequestWaffoPay)
//selfRoute.POST("/waffo-pancake/amount", controller.RequestWaffoPancakeAmount)
//selfRoute.POST("/waffo-pancake/pay", middleware.CriticalRateLimit(), controller.RequestWaffoPancakePay)
selfRoute.POST("/aff_transfer", controller.TransferAffQuota) selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
selfRoute.PUT("/setting", controller.UpdateUserSetting) selfRoute.PUT("/setting", controller.UpdateUserSetting)
+1 -1
View File
@@ -232,7 +232,7 @@ func (s *BillingSession) preConsume(c *gin.Context, quota int) *types.NewAPIErro
func (s *BillingSession) reserveFunding(delta int) error { func (s *BillingSession) reserveFunding(delta int) error {
switch funding := s.funding.(type) { switch funding := s.funding.(type) {
case *WalletFunding: case *WalletFunding:
if err := model.DecreaseUserQuota(funding.userId, delta); err != nil { if err := model.DecreaseUserQuota(funding.userId, delta, false); err != nil {
return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry()) return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
} }
funding.consumed += delta funding.consumed += delta
+1 -38
View File
@@ -2,11 +2,9 @@ package service
import ( import (
"fmt" "fmt"
"net/http"
"strings" "strings"
"github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/operation_setting" "github.com/QuantumNous/new-api/setting/operation_setting"
@@ -44,7 +42,7 @@ func EnableChannel(channelId int, usingKey string, channelName string) {
} }
} }
func ShouldDisableChannel(channelType int, err *types.NewAPIError) bool { func ShouldDisableChannel(err *types.NewAPIError) bool {
if !common.AutomaticDisableChannelEnabled { if !common.AutomaticDisableChannelEnabled {
return false return false
} }
@@ -60,41 +58,6 @@ func ShouldDisableChannel(channelType int, err *types.NewAPIError) bool {
if operation_setting.ShouldDisableByStatusCode(err.StatusCode) { if operation_setting.ShouldDisableByStatusCode(err.StatusCode) {
return true return true
} }
//if err.StatusCode == http.StatusUnauthorized {
// return true
//}
if err.StatusCode == http.StatusForbidden {
switch channelType {
case constant.ChannelTypeGemini:
return true
}
}
oaiErr := err.ToOpenAIError()
switch oaiErr.Code {
case "invalid_api_key":
return true
case "account_deactivated":
return true
case "billing_not_active":
return true
case "pre_consume_token_quota_failed":
return true
case "Arrearage":
return true
}
switch oaiErr.Type {
case "insufficient_quota":
return true
case "insufficient_user_quota":
return true
// https://docs.anthropic.com/claude/reference/errors
case "authentication_error":
return true
case "permission_error":
return true
case "forbidden":
return true
}
lowerMessage := strings.ToLower(err.Error()) lowerMessage := strings.ToLower(err.Error())
search, _ := AcSearch(lowerMessage, operation_setting.AutomaticDisableKeywords, true) search, _ := AcSearch(lowerMessage, operation_setting.AutomaticDisableKeywords, true)
+9 -1
View File
@@ -28,6 +28,10 @@ var (
codexCredentialRefreshRunning atomic.Bool codexCredentialRefreshRunning atomic.Bool
) )
func shouldAutoRefreshCodexChannelStatus(status int) bool {
return status == common.ChannelStatusEnabled || status == common.ChannelStatusAutoDisabled
}
func StartCodexCredentialAutoRefreshTask() { func StartCodexCredentialAutoRefreshTask() {
codexCredentialRefreshOnce.Do(func() { codexCredentialRefreshOnce.Do(func() {
if !common.IsMasterNode { if !common.IsMasterNode {
@@ -65,7 +69,11 @@ func runCodexCredentialAutoRefreshOnce() {
var channels []*model.Channel var channels []*model.Channel
err := model.DB. err := model.DB.
Select("id", "name", "key", "status", "channel_info"). Select("id", "name", "key", "status", "channel_info").
Where("type = ? AND status = 1", constant.ChannelTypeCodex). Where("type = ? AND (status = ? OR status = ?)",
constant.ChannelTypeCodex,
common.ChannelStatusEnabled,
common.ChannelStatusAutoDisabled,
).
Order("id asc"). Order("id asc").
Limit(codexCredentialRefreshBatchSize). Limit(codexCredentialRefreshBatchSize).
Offset(offset). Offset(offset).
+2 -2
View File
@@ -37,7 +37,7 @@ func (w *WalletFunding) PreConsume(amount int) error {
if amount <= 0 { if amount <= 0 {
return nil return nil
} }
if err := model.DecreaseUserQuota(w.userId, amount); err != nil { if err := model.DecreaseUserQuota(w.userId, amount, false); err != nil {
return err return err
} }
w.consumed = amount w.consumed = amount
@@ -49,7 +49,7 @@ func (w *WalletFunding) Settle(delta int) error {
return nil return nil
} }
if delta > 0 { if delta > 0 {
return model.DecreaseUserQuota(w.userId, delta) return model.DecreaseUserQuota(w.userId, delta, false)
} }
return model.IncreaseUserQuota(w.userId, -delta, false) return model.IncreaseUserQuota(w.userId, -delta, false)
} }
+1 -1
View File
@@ -413,7 +413,7 @@ func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQu
} else { } else {
// Wallet // Wallet
if quota > 0 { if quota > 0 {
err = model.DecreaseUserQuota(relayInfo.UserId, quota) err = model.DecreaseUserQuota(relayInfo.UserId, quota, false)
} else { } else {
err = model.IncreaseUserQuota(relayInfo.UserId, -quota, false) err = model.IncreaseUserQuota(relayInfo.UserId, -quota, false)
} }
+1 -1
View File
@@ -90,7 +90,7 @@ func taskAdjustFunding(task *model.Task, delta int) error {
return model.PostConsumeUserSubscriptionDelta(task.PrivateData.SubscriptionId, int64(delta)) return model.PostConsumeUserSubscriptionDelta(task.PrivateData.SubscriptionId, int64(delta))
} }
if delta > 0 { if delta > 0 {
return model.DecreaseUserQuota(task.UserId, delta) return model.DecreaseUserQuota(task.UserId, delta, false)
} }
return model.IncreaseUserQuota(task.UserId, -delta, false) return model.IncreaseUserQuota(task.UserId, -delta, false)
} }
+2
View File
@@ -42,6 +42,7 @@ func TestMain(m *testing.M) {
&model.Token{}, &model.Token{},
&model.Log{}, &model.Log{},
&model.Channel{}, &model.Channel{},
&model.TopUp{},
&model.UserSubscription{}, &model.UserSubscription{},
); err != nil { ); err != nil {
panic("failed to migrate: " + err.Error()) panic("failed to migrate: " + err.Error())
@@ -62,6 +63,7 @@ func truncate(t *testing.T) {
model.DB.Exec("DELETE FROM tokens") model.DB.Exec("DELETE FROM tokens")
model.DB.Exec("DELETE FROM logs") model.DB.Exec("DELETE FROM logs")
model.DB.Exec("DELETE FROM channels") model.DB.Exec("DELETE FROM channels")
model.DB.Exec("DELETE FROM top_ups")
model.DB.Exec("DELETE FROM user_subscriptions") model.DB.Exec("DELETE FROM user_subscriptions")
}) })
} }
+398
View File
@@ -0,0 +1,398 @@
package service
import (
"bytes"
"context"
"crypto"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"fmt"
"io"
"math"
"net/http"
"strconv"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting"
)
const (
waffoPancakeAuthBaseURL = "https://waffo-pancake-auth-service.vercel.app"
waffoPancakeCheckoutPath = "/v1/actions/checkout/create-session"
waffoPancakeDefaultTolerance = 5 * time.Minute
)
type WaffoPancakePriceSnapshot struct {
Amount string `json:"amount"`
TaxIncluded bool `json:"taxIncluded"`
TaxCategory string `json:"taxCategory"`
}
type WaffoPancakeCreateSessionParams struct {
StoreID string `json:"storeId"`
ProductID string `json:"productId"`
ProductType string `json:"productType"`
Currency string `json:"currency"`
PriceSnapshot *WaffoPancakePriceSnapshot `json:"priceSnapshot,omitempty"`
BuyerEmail string `json:"buyerEmail,omitempty"`
SuccessURL string `json:"successUrl,omitempty"`
ExpiresInSeconds *int `json:"expiresInSeconds,omitempty"`
}
type WaffoPancakeCheckoutSession struct {
SessionID string `json:"sessionId"`
CheckoutURL string `json:"checkoutUrl"`
ExpiresAt string `json:"expiresAt"`
OrderID string `json:"orderId"`
}
type waffoPancakeAPIError struct {
Message string `json:"message"`
Layer string `json:"layer"`
}
type waffoPancakeCreateSessionResponse struct {
Data *WaffoPancakeCheckoutSession `json:"data"`
Errors []waffoPancakeAPIError `json:"errors"`
}
type waffoPancakeWebhookData struct {
ID string `json:"id"`
OrderID string `json:"orderId"`
BuyerEmail string `json:"buyerEmail"`
Currency string `json:"currency"`
Amount dto.StringValue `json:"amount"`
TaxAmount dto.StringValue `json:"taxAmount"`
ProductName string `json:"productName"`
}
type waffoPancakeWebhookEvent struct {
ID string `json:"id"`
Timestamp string `json:"timestamp"`
EventType string `json:"eventType"`
EventID string `json:"eventId"`
StoreID string `json:"storeId"`
Mode string `json:"mode"`
Data waffoPancakeWebhookData `json:"data"`
}
func (e *waffoPancakeWebhookEvent) NormalizedEventType() string {
if e == nil {
return ""
}
return e.EventType
}
func CreateWaffoPancakeCheckoutSession(ctx context.Context, params *WaffoPancakeCreateSessionParams) (*WaffoPancakeCheckoutSession, error) {
if params == nil {
return nil, fmt.Errorf("missing checkout params")
}
body, err := common.Marshal(params)
if err != nil {
return nil, fmt.Errorf("marshal Waffo Pancake checkout payload: %w", err)
}
privateKey, err := normalizeRSAPrivateKey(setting.WaffoPancakePrivateKey)
if err != nil {
return nil, err
}
timestamp := strconv.FormatInt(time.Now().Unix(), 10)
signature, err := signWaffoPancakeRequest(http.MethodPost, waffoPancakeCheckoutPath, timestamp, string(body), privateKey)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, waffoPancakeAuthBaseURL+waffoPancakeCheckoutPath, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("build Waffo Pancake checkout request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Merchant-Id", setting.WaffoPancakeMerchantID)
req.Header.Set("X-Timestamp", timestamp)
req.Header.Set("X-Signature", signature)
if setting.WaffoPancakeSandbox {
req.Header.Set("X-Environment", "test")
} else {
req.Header.Set("X-Environment", "prod")
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("request Waffo Pancake checkout session: %w", err)
}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read Waffo Pancake checkout response: %w", err)
}
var result waffoPancakeCreateSessionResponse
if err := common.Unmarshal(responseBody, &result); err != nil {
return nil, fmt.Errorf("decode Waffo Pancake checkout response: %w", err)
}
if resp.StatusCode >= http.StatusBadRequest {
if len(result.Errors) > 0 {
return nil, fmt.Errorf("Waffo Pancake error (%d): %s", resp.StatusCode, result.Errors[0].Message)
}
return nil, fmt.Errorf("Waffo Pancake checkout request failed with status %d", resp.StatusCode)
}
if len(result.Errors) > 0 {
return nil, fmt.Errorf("Waffo Pancake error: %s", result.Errors[0].Message)
}
if result.Data == nil || result.Data.CheckoutURL == "" || strings.TrimSpace(result.Data.SessionID) == "" {
return nil, fmt.Errorf("Waffo Pancake returned empty checkout session")
}
return result.Data, nil
}
func VerifyConfiguredWaffoPancakeWebhook(payload string, signatureHeader string) (*waffoPancakeWebhookEvent, error) {
environment := resolveWaffoPancakeWebhookEnvironment(payload)
return verifyWaffoPancakeWebhook(payload, signatureHeader, environment)
}
func ResolveWaffoPancakeTradeNo(event *waffoPancakeWebhookEvent) (string, error) {
if event == nil {
return "", fmt.Errorf("missing webhook event")
}
if tradeNo := strings.TrimSpace(event.Data.OrderID); tradeNo != "" {
topUp := model.GetTopUpByTradeNo(tradeNo)
if topUp != nil && topUp.PaymentMethod == model.PaymentMethodWaffoPancake {
return tradeNo, nil
}
return "", fmt.Errorf("waffo pancake order not found for webhook orderId=%s", tradeNo)
}
return "", fmt.Errorf("missing webhook orderId")
}
func normalizeRSAPrivateKey(raw string) (string, error) {
return normalizePEMKey(raw, "PRIVATE KEY", "RSA PRIVATE KEY")
}
func normalizeRSAPublicKey(raw string) (string, error) {
return normalizePEMKey(raw, "PUBLIC KEY", "RSA PUBLIC KEY")
}
func normalizePEMKey(raw string, pkcs8Type string, pkcs1Type string) (string, error) {
if strings.TrimSpace(raw) == "" {
return "", fmt.Errorf("%s is empty", strings.ToLower(pkcs8Type))
}
normalized := strings.TrimSpace(strings.ReplaceAll(raw, `\n`, "\n"))
if strings.Contains(normalized, "BEGIN ") {
block, _ := pem.Decode([]byte(normalized))
if block == nil {
return "", fmt.Errorf("invalid PEM encoded %s", strings.ToLower(pkcs8Type))
}
return string(pem.EncodeToMemory(block)), nil
}
der, err := base64.StdEncoding.DecodeString(strings.ReplaceAll(normalized, "\n", ""))
if err != nil {
return "", fmt.Errorf("invalid base64 encoded %s: %w", strings.ToLower(pkcs8Type), err)
}
pemType := pkcs8Type
if pkcs8Type == "PRIVATE KEY" {
if _, err := x509.ParsePKCS8PrivateKey(der); err != nil {
if _, err := x509.ParsePKCS1PrivateKey(der); err == nil {
pemType = pkcs1Type
} else {
return "", fmt.Errorf("invalid RSA private key")
}
}
} else {
if _, err := x509.ParsePKIXPublicKey(der); err != nil {
if _, err := x509.ParsePKCS1PublicKey(der); err == nil {
pemType = pkcs1Type
} else {
return "", fmt.Errorf("invalid RSA public key")
}
}
}
return string(pem.EncodeToMemory(&pem.Block{Type: pemType, Bytes: der})), nil
}
func signWaffoPancakeRequest(method string, path string, timestamp string, body string, privateKeyPEM string) (string, error) {
block, _ := pem.Decode([]byte(privateKeyPEM))
if block == nil {
return "", fmt.Errorf("invalid RSA private key PEM")
}
var privateKey *rsa.PrivateKey
switch block.Type {
case "PRIVATE KEY":
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return "", fmt.Errorf("parse PKCS#8 private key: %w", err)
}
parsed, ok := key.(*rsa.PrivateKey)
if !ok {
return "", fmt.Errorf("private key is not RSA")
}
privateKey = parsed
case "RSA PRIVATE KEY":
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return "", fmt.Errorf("parse PKCS#1 private key: %w", err)
}
privateKey = key
default:
return "", fmt.Errorf("unsupported private key type: %s", block.Type)
}
canonicalRequest := buildWaffoPancakeCanonicalRequest(method, path, timestamp, body)
digest := sha256.Sum256([]byte(canonicalRequest))
signature, err := rsa.SignPKCS1v15(nil, privateKey, crypto.SHA256, digest[:])
if err != nil {
return "", fmt.Errorf("sign Waffo Pancake request: %w", err)
}
return base64.StdEncoding.EncodeToString(signature), nil
}
func buildWaffoPancakeCanonicalRequest(method string, path string, timestamp string, body string) string {
bodyHash := sha256.Sum256([]byte(body))
return fmt.Sprintf(
"%s\n%s\n%s\n%s",
strings.ToUpper(method),
path,
timestamp,
base64.StdEncoding.EncodeToString(bodyHash[:]),
)
}
func verifyWaffoPancakeWebhook(payload string, signatureHeader string, environment string) (*waffoPancakeWebhookEvent, error) {
if signatureHeader == "" {
return nil, fmt.Errorf("missing X-Waffo-Signature header")
}
timestampPart, signaturePart := parseWaffoPancakeSignatureHeader(signatureHeader)
if timestampPart == "" || signaturePart == "" {
return nil, fmt.Errorf("malformed X-Waffo-Signature header")
}
timestampMs, err := strconv.ParseInt(timestampPart, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid timestamp in X-Waffo-Signature header")
}
if math.Abs(float64(time.Now().UnixMilli()-timestampMs)) > float64(waffoPancakeDefaultTolerance.Milliseconds()) {
return nil, fmt.Errorf("webhook timestamp outside tolerance window")
}
signatureInput := fmt.Sprintf("%s.%s", timestampPart, payload)
if err := verifyWaffoPancakeWebhookWithKey(signatureInput, signaturePart, resolveWaffoPancakeWebhookPublicKey(environment)); err != nil {
return nil, fmt.Errorf("invalid webhook signature")
}
var event waffoPancakeWebhookEvent
if err := common.Unmarshal([]byte(payload), &event); err != nil {
return nil, fmt.Errorf("parse Waffo Pancake webhook payload: %w", err)
}
return &event, nil
}
func parseWaffoPancakeSignatureHeader(header string) (string, string) {
var timestampPart string
var signaturePart string
for _, pair := range strings.Split(header, ",") {
key, value, found := strings.Cut(strings.TrimSpace(pair), "=")
if !found {
continue
}
switch key {
case "t":
timestampPart = value
case "v1":
signaturePart = value
}
}
return timestampPart, signaturePart
}
func resolveWaffoPancakeWebhookEnvironment(payload string) string {
var envelope struct {
Mode string `json:"mode"`
}
if err := common.Unmarshal([]byte(payload), &envelope); err != nil {
if setting.WaffoPancakeSandbox {
return "test"
}
return "prod"
}
switch strings.ToLower(strings.TrimSpace(envelope.Mode)) {
case "test":
return "test"
case "prod":
return "prod"
default:
if setting.WaffoPancakeSandbox {
return "test"
}
return "prod"
}
}
func resolveWaffoPancakeWebhookPublicKey(environment string) string {
if environment == "prod" {
return strings.TrimSpace(setting.WaffoPancakeWebhookPublicKey)
}
return strings.TrimSpace(setting.WaffoPancakeWebhookTestKey)
}
func verifyWaffoPancakeWebhookWithKey(signatureInput string, signaturePart string, rawPublicKey string) error {
publicKeyPEM, err := normalizeRSAPublicKey(rawPublicKey)
if err != nil {
return err
}
block, _ := pem.Decode([]byte(publicKeyPEM))
if block == nil {
return fmt.Errorf("invalid RSA public key PEM")
}
var publicKey *rsa.PublicKey
switch block.Type {
case "PUBLIC KEY":
key, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return fmt.Errorf("parse PKIX public key: %w", err)
}
parsed, ok := key.(*rsa.PublicKey)
if !ok {
return fmt.Errorf("public key is not RSA")
}
publicKey = parsed
case "RSA PUBLIC KEY":
key, err := x509.ParsePKCS1PublicKey(block.Bytes)
if err != nil {
return fmt.Errorf("parse PKCS#1 public key: %w", err)
}
publicKey = key
default:
return fmt.Errorf("unsupported public key type: %s", block.Type)
}
signature, err := base64.StdEncoding.DecodeString(signaturePart)
if err != nil {
return fmt.Errorf("decode webhook signature: %w", err)
}
digest := sha256.Sum256([]byte(signatureInput))
if err := rsa.VerifyPKCS1v15(publicKey, crypto.SHA256, digest[:], signature); err != nil {
return fmt.Errorf("verify webhook signature: %w", err)
}
return nil
}
+157
View File
@@ -0,0 +1,157 @@
package service
import (
"fmt"
"strings"
"testing"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting"
"github.com/glebarez/sqlite"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
)
func setupWaffoPancakeTestDB(t *testing.T) *gorm.DB {
t.Helper()
common.UsingSQLite = true
common.UsingMySQL = false
common.UsingPostgreSQL = false
common.RedisEnabled = false
dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_"))
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
require.NoError(t, err)
model.DB = db
model.LOG_DB = db
require.NoError(t, db.AutoMigrate(&model.User{}, &model.TopUp{}))
t.Cleanup(func() {
sqlDB, err := db.DB()
if err == nil {
_ = sqlDB.Close()
}
})
return db
}
func TestWaffoPancakeCreateSessionResponseParsesDocumentedPayload(t *testing.T) {
var result waffoPancakeCreateSessionResponse
err := common.Unmarshal([]byte(`{
"data": {
"sessionId": "cs_550e8400-e29b-41d4-a716-446655440000",
"checkoutUrl": "https://checkout.waffo.ai/my-store-abc123/checkout/cs_550e8400-e29b-41d4-a716-446655440000",
"expiresAt": "2026-01-22T10:30:00.000Z"
}
}`), &result)
require.NoError(t, err)
require.NotNil(t, result.Data)
require.Equal(t, "cs_550e8400-e29b-41d4-a716-446655440000", result.Data.SessionID)
require.Empty(t, result.Data.OrderID)
}
func TestResolveWaffoPancakeTradeNo_UsesWebhookOrderIDWhenLocalOrderExists(t *testing.T) {
db := setupWaffoPancakeTestDB(t)
topUp := &model.TopUp{
UserId: 1,
Amount: 10,
Money: 29,
TradeNo: "ORD_5dXBtmF2HLlHfbPNm0Wcnz",
PaymentMethod: model.PaymentMethodWaffoPancake,
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
}
require.NoError(t, db.Create(topUp).Error)
tradeNo, err := ResolveWaffoPancakeTradeNo(&waffoPancakeWebhookEvent{
Data: waffoPancakeWebhookData{
OrderID: "ORD_5dXBtmF2HLlHfbPNm0Wcnz",
},
})
require.NoError(t, err)
require.Equal(t, "ORD_5dXBtmF2HLlHfbPNm0Wcnz", tradeNo)
}
func TestResolveWaffoPancakeTradeNo_FailsWhenWebhookOrderIDIsUnknown(t *testing.T) {
db := setupWaffoPancakeTestDB(t)
user := &model.User{
Id: 42,
Email: "buyer@example.com",
Username: "buyer",
Status: common.UserStatusEnabled,
}
require.NoError(t, db.Create(user).Error)
topUp := &model.TopUp{
UserId: user.Id,
Amount: 10,
Money: 29,
TradeNo: "WAFFO_PANCAKE-42-123456-abc123",
PaymentMethod: model.PaymentMethodWaffoPancake,
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
}
require.NoError(t, db.Create(topUp).Error)
tradeNo, err := ResolveWaffoPancakeTradeNo(&waffoPancakeWebhookEvent{
Data: waffoPancakeWebhookData{
OrderID: "ORD_unknown",
BuyerEmail: user.Email,
Amount: "29.00",
},
})
require.Error(t, err)
require.Empty(t, tradeNo)
}
func TestResolveWaffoPancakeWebhookEnvironment(t *testing.T) {
originalSandbox := setting.WaffoPancakeSandbox
t.Cleanup(func() {
setting.WaffoPancakeSandbox = originalSandbox
})
testCases := []struct {
name string
payload string
expected string
sandbox bool
}{
{
name: "test mode",
payload: `{"mode":"test"}`,
expected: "test",
},
{
name: "prod mode",
payload: `{"mode":"prod"}`,
expected: "prod",
},
{
name: "missing mode falls back to sandbox",
payload: `{}`,
expected: "test",
sandbox: true,
},
{
name: "invalid mode falls back to prod",
payload: `{"mode":"staging"}`,
expected: "prod",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
setting.WaffoPancakeSandbox = tc.sandbox
environment := resolveWaffoPancakeWebhookEnvironment(tc.payload)
require.Equal(t, tc.expected, environment)
})
}
}
+16
View File
@@ -0,0 +1,16 @@
package setting
var (
WaffoPancakeEnabled bool
WaffoPancakeSandbox bool
WaffoPancakeMerchantID string
WaffoPancakePrivateKey string
WaffoPancakeWebhookPublicKey string
WaffoPancakeWebhookTestKey string
WaffoPancakeStoreID string
WaffoPancakeProductID string
WaffoPancakeReturnURL string
WaffoPancakeCurrency string = "USD"
WaffoPancakeUnitPrice float64 = 1.0
WaffoPancakeMinTopUp int = 1
)
+14
View File
@@ -64,6 +64,13 @@ var defaultCacheRatio = map[string]float64{
"claude-opus-4-6-high": 0.1, "claude-opus-4-6-high": 0.1,
"claude-opus-4-6-medium": 0.1, "claude-opus-4-6-medium": 0.1,
"claude-opus-4-6-low": 0.1, "claude-opus-4-6-low": 0.1,
"claude-opus-4-7": 0.1,
"claude-opus-4-7-thinking": 0.1,
"claude-opus-4-7-max": 0.1,
"claude-opus-4-7-xhigh": 0.1,
"claude-opus-4-7-high": 0.1,
"claude-opus-4-7-medium": 0.1,
"claude-opus-4-7-low": 0.1,
} }
var defaultCreateCacheRatio = map[string]float64{ var defaultCreateCacheRatio = map[string]float64{
@@ -92,6 +99,13 @@ var defaultCreateCacheRatio = map[string]float64{
"claude-opus-4-6-high": 1.25, "claude-opus-4-6-high": 1.25,
"claude-opus-4-6-medium": 1.25, "claude-opus-4-6-medium": 1.25,
"claude-opus-4-6-low": 1.25, "claude-opus-4-6-low": 1.25,
"claude-opus-4-7": 1.25,
"claude-opus-4-7-thinking": 1.25,
"claude-opus-4-7-max": 1.25,
"claude-opus-4-7-xhigh": 1.25,
"claude-opus-4-7-high": 1.25,
"claude-opus-4-7-medium": 1.25,
"claude-opus-4-7-low": 1.25,
} }
//var defaultCreateCacheRatio = map[string]float64{} //var defaultCreateCacheRatio = map[string]float64{}
+6
View File
@@ -146,6 +146,12 @@ var defaultModelRatio = map[string]float64{
"claude-opus-4-6-high": 2.5, "claude-opus-4-6-high": 2.5,
"claude-opus-4-6-medium": 2.5, "claude-opus-4-6-medium": 2.5,
"claude-opus-4-6-low": 2.5, "claude-opus-4-6-low": 2.5,
"claude-opus-4-7": 2.5,
"claude-opus-4-7-max": 2.5,
"claude-opus-4-7-xhigh": 2.5,
"claude-opus-4-7-high": 2.5,
"claude-opus-4-7-medium": 2.5,
"claude-opus-4-7-low": 2.5,
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens "claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"claude-opus-4-20250514": 7.5, "claude-opus-4-20250514": 7.5,
"claude-opus-4-1-20250805": 7.5, "claude-opus-4-1-20250805": 7.5,
+1 -1
View File
@@ -6,7 +6,7 @@ import (
"github.com/samber/lo" "github.com/samber/lo"
) )
var EffortSuffixes = []string{"-max", "-high", "-medium", "-low", "-minimal"} var EffortSuffixes = []string{"-max", "-xhigh", "-high", "-medium", "-low", "-minimal"}
// TrimEffortSuffix -> modelName level(low) exists // TrimEffortSuffix -> modelName level(low) exists
func TrimEffortSuffix(modelName string) (string, string, bool) { func TrimEffortSuffix(modelName string) (string, string, bool) {
+6
View File
@@ -390,6 +390,12 @@ func ErrOptionWithNoRecordErrorLog() NewAPIErrorOptions {
} }
} }
func ErrOptionWithStatusCode(statusCode int) NewAPIErrorOptions {
return func(e *NewAPIError) {
e.StatusCode = statusCode
}
}
func ErrOptionWithHideErrMsg(replaceStr string) NewAPIErrorOptions { func ErrOptionWithHideErrMsg(replaceStr string) NewAPIErrorOptions {
return func(e *NewAPIError) { return func(e *NewAPIError) {
if common.DebugEnabled { if common.DebugEnabled {
+3 -4
View File
@@ -1,6 +1,5 @@
{ {
"lockfileVersion": 1, "lockfileVersion": 1,
"configVersion": 0,
"workspaces": { "workspaces": {
"": { "": {
"name": "react-template", "name": "react-template",
@@ -11,7 +10,7 @@
"@visactor/react-vchart": "~1.8.8", "@visactor/react-vchart": "~1.8.8",
"@visactor/vchart": "~1.8.8", "@visactor/vchart": "~1.8.8",
"@visactor/vchart-semi-theme": "~1.8.8", "@visactor/vchart-semi-theme": "~1.8.8",
"axios": "1.13.5", "axios": "1.15.0",
"clsx": "^2.1.1", "clsx": "^2.1.1",
"dayjs": "^1.11.11", "dayjs": "^1.11.11",
"history": "^5.3.0", "history": "^5.3.0",
@@ -777,7 +776,7 @@
"autoprefixer": ["autoprefixer@10.4.21", "", { "dependencies": { "browserslist": "^4.24.4", "caniuse-lite": "^1.0.30001702", "fraction.js": "^4.3.7", "normalize-range": "^0.1.2", "picocolors": "^1.1.1", "postcss-value-parser": "^4.2.0" }, "peerDependencies": { "postcss": "^8.1.0" }, "bin": { "autoprefixer": "bin/autoprefixer" } }, "sha512-O+A6LWV5LDHSJD3LjHYoNi4VLsj/Whi7k6zG12xTYaU4cQ8oxQGckXNX8cRHK5yOZ/ppVHe0ZBXGzSV9jXdVbQ=="], "autoprefixer": ["autoprefixer@10.4.21", "", { "dependencies": { "browserslist": "^4.24.4", "caniuse-lite": "^1.0.30001702", "fraction.js": "^4.3.7", "normalize-range": "^0.1.2", "picocolors": "^1.1.1", "postcss-value-parser": "^4.2.0" }, "peerDependencies": { "postcss": "^8.1.0" }, "bin": { "autoprefixer": "bin/autoprefixer" } }, "sha512-O+A6LWV5LDHSJD3LjHYoNi4VLsj/Whi7k6zG12xTYaU4cQ8oxQGckXNX8cRHK5yOZ/ppVHe0ZBXGzSV9jXdVbQ=="],
"axios": ["axios@1.13.5", "", { "dependencies": { "follow-redirects": "^1.15.11", "form-data": "^4.0.5", "proxy-from-env": "^1.1.0" } }, "sha512-cz4ur7Vb0xS4/KUN0tPWe44eqxrIu31me+fbang3ijiNscE129POzipJJA6zniq2C/Z6sJCjMimjS8Lc/GAs8Q=="], "axios": ["axios@1.15.0", "", { "dependencies": { "follow-redirects": "^1.15.11", "form-data": "^4.0.5", "proxy-from-env": "^2.1.0" } }, "sha512-wWyJDlAatxk30ZJer+GeCWS209sA42X+N5jU2jy6oHTp7ufw8uzUTVFBX9+wTfAlhiJXGS0Bq7X6efruWjuK9Q=="],
"babel-plugin-macros": ["babel-plugin-macros@3.1.0", "", { "dependencies": { "@babel/runtime": "^7.12.5", "cosmiconfig": "^7.0.0", "resolve": "^1.19.0" } }, "sha512-Cg7TFGpIr01vOQNODXOOaGz2NpCU5gl8x1qJFbb6hbZxR7XrcE2vtbAsTAbJ7/xwJtUuJEw8K8Zr/AE0LHlesg=="], "babel-plugin-macros": ["babel-plugin-macros@3.1.0", "", { "dependencies": { "@babel/runtime": "^7.12.5", "cosmiconfig": "^7.0.0", "resolve": "^1.19.0" } }, "sha512-Cg7TFGpIr01vOQNODXOOaGz2NpCU5gl8x1qJFbb6hbZxR7XrcE2vtbAsTAbJ7/xwJtUuJEw8K8Zr/AE0LHlesg=="],
@@ -1657,7 +1656,7 @@
"protocol-buffers-schema": ["protocol-buffers-schema@3.6.0", "", {}, "sha512-TdDRD+/QNdrCGCE7v8340QyuXd4kIWIgapsE2+n/SaGiSSbomYl4TjHlvIoCWRpE7wFt02EpB35VVA2ImcBVqw=="], "protocol-buffers-schema": ["protocol-buffers-schema@3.6.0", "", {}, "sha512-TdDRD+/QNdrCGCE7v8340QyuXd4kIWIgapsE2+n/SaGiSSbomYl4TjHlvIoCWRpE7wFt02EpB35VVA2ImcBVqw=="],
"proxy-from-env": ["proxy-from-env@1.1.0", "", {}, "sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg=="], "proxy-from-env": ["proxy-from-env@2.1.0", "", {}, "sha512-cJ+oHTW1VAEa8cJslgmUZrc+sjRKgAKl3Zyse6+PV38hZe/V6Z14TbCuXcan9F9ghlz4QrFr2c92TNF82UkYHA=="],
"punycode": ["punycode@2.3.1", "", {}, "sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg=="], "punycode": ["punycode@2.3.1", "", {}, "sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg=="],
+1 -1
View File
@@ -10,7 +10,7 @@
"@visactor/react-vchart": "~1.8.8", "@visactor/react-vchart": "~1.8.8",
"@visactor/vchart": "~1.8.8", "@visactor/vchart": "~1.8.8",
"@visactor/vchart-semi-theme": "~1.8.8", "@visactor/vchart-semi-theme": "~1.8.8",
"axios": "1.13.5", "axios": "1.15.0",
"clsx": "^2.1.1", "clsx": "^2.1.1",
"dayjs": "^1.11.11", "dayjs": "^1.11.11",
"history": "^5.3.0", "history": "^5.3.0",
@@ -21,8 +21,9 @@ import React, { useRef, useEffect } from 'react';
import { Typography, TextArea, Button } from '@douyinfe/semi-ui'; import { Typography, TextArea, Button } from '@douyinfe/semi-ui';
import MarkdownRenderer from '../common/markdown/MarkdownRenderer'; import MarkdownRenderer from '../common/markdown/MarkdownRenderer';
import ThinkingContent from './ThinkingContent'; import ThinkingContent from './ThinkingContent';
import { Loader2, Check, X } from 'lucide-react'; import { Loader2, Check, X, Settings, AlertTriangle } from 'lucide-react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { isAdmin } from '../../helpers/utils';
const MessageContent = ({ const MessageContent = ({
message, message,
@@ -64,6 +65,44 @@ const MessageContent = ({
errorText = t('请求发生错误'); errorText = t('请求发生错误');
} }
if (message.errorCode === 'model_price_error') {
return (
<div className={`${className}`}>
<div
className='rounded-lg p-3 space-y-2'
style={{
background: 'var(--semi-color-bg-0)',
border: '1px solid var(--semi-color-border)',
}}
>
<div className='flex items-center gap-2'>
<AlertTriangle size={16} className='text-orange-500 shrink-0' />
<Typography.Text strong className='!text-[var(--semi-color-text-0)]'>
{t('模型价格未配置')}
</Typography.Text>
</div>
<Typography.Paragraph
className='!text-[var(--semi-color-text-1)] !text-sm !mb-0'
style={{ wordBreak: 'break-word' }}
>
{errorText}
</Typography.Paragraph>
{isAdmin() && (
<Button
size='small'
theme='light'
type='warning'
icon={<Settings size={14} />}
onClick={() => window.open('/console/setting?tab=ratio', '_blank')}
>
{t('前往设置')}
</Button>
)}
</div>
</div>
);
}
return ( return (
<div className={`${className}`}> <div className={`${className}`}>
<Typography.Text className='text-white'>{errorText}</Typography.Text> <Typography.Text className='text-white'>{errorText}</Typography.Text>
+75 -15
View File
@@ -18,12 +18,13 @@ For commercial licensing, please contact support@quantumnous.com
*/ */
import React, { useEffect, useState } from 'react'; import React, { useEffect, useState } from 'react';
import { Card, Spin } from '@douyinfe/semi-ui'; import { Card, Spin, Tabs } from '@douyinfe/semi-ui';
import SettingsGeneralPayment from '../../pages/Setting/Payment/SettingsGeneralPayment'; import SettingsGeneralPayment from '../../pages/Setting/Payment/SettingsGeneralPayment';
import SettingsPaymentGateway from '../../pages/Setting/Payment/SettingsPaymentGateway'; import SettingsPaymentGateway from '../../pages/Setting/Payment/SettingsPaymentGateway';
import SettingsPaymentGatewayStripe from '../../pages/Setting/Payment/SettingsPaymentGatewayStripe'; import SettingsPaymentGatewayStripe from '../../pages/Setting/Payment/SettingsPaymentGatewayStripe';
import SettingsPaymentGatewayCreem from '../../pages/Setting/Payment/SettingsPaymentGatewayCreem'; import SettingsPaymentGatewayCreem from '../../pages/Setting/Payment/SettingsPaymentGatewayCreem';
import SettingsPaymentGatewayWaffo from '../../pages/Setting/Payment/SettingsPaymentGatewayWaffo'; import SettingsPaymentGatewayWaffo from '../../pages/Setting/Payment/SettingsPaymentGatewayWaffo';
import SettingsPaymentGatewayWaffoPancake from '../../pages/Setting/Payment/SettingsPaymentGatewayWaffoPancake';
import { API, showError, toBoolean } from '../../helpers'; import { API, showError, toBoolean } from '../../helpers';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@@ -48,6 +49,17 @@ const PaymentSetting = () => {
StripeUnitPrice: 8.0, StripeUnitPrice: 8.0,
StripeMinTopUp: 1, StripeMinTopUp: 1,
StripePromotionCodesEnabled: false, StripePromotionCodesEnabled: false,
WaffoPancakeEnabled: false,
WaffoPancakeSandbox: false,
WaffoPancakeMerchantID: '',
WaffoPancakePrivateKey: '',
WaffoPancakeStoreID: '',
WaffoPancakeProductID: '',
WaffoPancakeReturnURL: '',
WaffoPancakeCurrency: 'USD',
WaffoPancakeUnitPrice: 1.0,
WaffoPancakeMinTopUp: 1,
}); });
let [loading, setLoading] = useState(false); let [loading, setLoading] = useState(false);
@@ -96,8 +108,21 @@ const PaymentSetting = () => {
case 'MinTopUp': case 'MinTopUp':
case 'StripeUnitPrice': case 'StripeUnitPrice':
case 'StripeMinTopUp': case 'StripeMinTopUp':
case 'WaffoPancakeUnitPrice':
case 'WaffoPancakeMinTopUp':
newInputs[item.key] = parseFloat(item.value); newInputs[item.key] = parseFloat(item.value);
break; break;
case 'WaffoPancakeMerchantID':
case 'WaffoPancakePrivateKey':
case 'WaffoPancakeStoreID':
case 'WaffoPancakeProductID':
case 'WaffoPancakeReturnURL':
case 'WaffoPancakeCurrency':
newInputs[item.key] = item.value;
break;
case 'WaffoPancakeSandbox':
newInputs[item.key] = toBoolean(item.value);
break;
default: default:
if (item.key.endsWith('Enabled')) { if (item.key.endsWith('Enabled')) {
newInputs[item.key] = toBoolean(item.value); newInputs[item.key] = toBoolean(item.value);
@@ -108,7 +133,7 @@ const PaymentSetting = () => {
} }
}); });
setInputs(newInputs); setInputs((prev) => ({ ...prev, ...newInputs }));
} else { } else {
showError(t(message)); showError(t(message));
} }
@@ -133,19 +158,54 @@ const PaymentSetting = () => {
<> <>
<Spin spinning={loading} size='large'> <Spin spinning={loading} size='large'>
<Card style={{ marginTop: '10px' }}> <Card style={{ marginTop: '10px' }}>
<SettingsGeneralPayment options={inputs} refresh={onRefresh} /> <Tabs
</Card> type='card'
<Card style={{ marginTop: '10px' }}> defaultActiveKey='general'
<SettingsPaymentGateway options={inputs} refresh={onRefresh} /> contentStyle={{ paddingTop: 24 }}
</Card> >
<Card style={{ marginTop: '10px' }}> <Tabs.TabPane tab={t('通用设置')} itemKey='general'>
<SettingsPaymentGatewayStripe options={inputs} refresh={onRefresh} /> <SettingsGeneralPayment
</Card> options={inputs}
<Card style={{ marginTop: '10px' }}> refresh={onRefresh}
<SettingsPaymentGatewayCreem options={inputs} refresh={onRefresh} /> hideSectionTitle
</Card> />
<Card style={{ marginTop: '10px' }}> </Tabs.TabPane>
<SettingsPaymentGatewayWaffo options={inputs} refresh={onRefresh} /> <Tabs.TabPane tab={t('易支付设置')} itemKey='epay'>
<SettingsPaymentGateway
options={inputs}
refresh={onRefresh}
hideSectionTitle
/>
</Tabs.TabPane>
<Tabs.TabPane tab={t('Stripe 设置')} itemKey='stripe'>
<SettingsPaymentGatewayStripe
options={inputs}
refresh={onRefresh}
hideSectionTitle
/>
</Tabs.TabPane>
<Tabs.TabPane tab={t('Creem 设置')} itemKey='creem'>
<SettingsPaymentGatewayCreem
options={inputs}
refresh={onRefresh}
hideSectionTitle
/>
</Tabs.TabPane>
<Tabs.TabPane tab={t('Waffo 设置')} itemKey='waffo'>
<SettingsPaymentGatewayWaffo
options={inputs}
refresh={onRefresh}
hideSectionTitle
/>
</Tabs.TabPane>
{/*<Tabs.TabPane tab={t('Waffo Pancake 设置')} itemKey='waffo-pancake'>*/}
{/* <SettingsPaymentGatewayWaffoPancake*/}
{/* options={inputs}*/}
{/* refresh={onRefresh}*/}
{/* hideSectionTitle*/}
{/* />*/}
{/*</Tabs.TabPane>*/}
</Tabs>
</Card> </Card>
</Spin> </Spin>
</> </>
+125 -20
View File
@@ -45,6 +45,8 @@ import EmailBindModal from './personal/modals/EmailBindModal';
import WeChatBindModal from './personal/modals/WeChatBindModal'; import WeChatBindModal from './personal/modals/WeChatBindModal';
import AccountDeleteModal from './personal/modals/AccountDeleteModal'; import AccountDeleteModal from './personal/modals/AccountDeleteModal';
import ChangePasswordModal from './personal/modals/ChangePasswordModal'; import ChangePasswordModal from './personal/modals/ChangePasswordModal';
import SecureVerificationModal from '../common/modals/SecureVerificationModal';
import { useSecureVerification } from '../../hooks/common/useSecureVerification';
const PersonalSetting = () => { const PersonalSetting = () => {
const [userState, userDispatch] = useContext(UserContext); const [userState, userDispatch] = useContext(UserContext);
@@ -76,6 +78,10 @@ const PersonalSetting = () => {
const [passkeyRegisterLoading, setPasskeyRegisterLoading] = useState(false); const [passkeyRegisterLoading, setPasskeyRegisterLoading] = useState(false);
const [passkeyDeleteLoading, setPasskeyDeleteLoading] = useState(false); const [passkeyDeleteLoading, setPasskeyDeleteLoading] = useState(false);
const [passkeySupported, setPasskeySupported] = useState(false); const [passkeySupported, setPasskeySupported] = useState(false);
const [
passkeyRequiredVerificationMethod,
setPasskeyRequiredVerificationMethod,
] = useState(null);
const [notificationSettings, setNotificationSettings] = useState({ const [notificationSettings, setNotificationSettings] = useState({
warningType: 'email', warningType: 'email',
warningThreshold: 100000, warningThreshold: 100000,
@@ -91,6 +97,34 @@ const PersonalSetting = () => {
recordIpLog: false, recordIpLog: false,
}); });
const {
isModalVisible: isPasskeyVerificationModalVisible,
verificationMethods: passkeyVerificationMethods,
verificationState: passkeyVerificationState,
startVerification: startPasskeyVerification,
executeVerification: executePasskeyVerification,
cancelVerification: cancelPasskeyVerification,
setVerificationCode: setPasskeyVerificationCode,
switchVerificationMethod: switchPasskeyVerificationMethod,
checkVerificationMethods: checkPasskeyVerificationMethods,
} = useSecureVerification({
onSuccess: () => {
setPasskeyRequiredVerificationMethod(null);
},
});
const visiblePasskeyVerificationMethods = passkeyRequiredVerificationMethod
? {
...passkeyVerificationMethods,
has2FA:
passkeyRequiredVerificationMethod === '2fa' &&
passkeyVerificationMethods.has2FA,
hasPasskey:
passkeyRequiredVerificationMethod === 'passkey' &&
passkeyVerificationMethods.hasPasskey,
}
: passkeyVerificationMethods;
useEffect(() => { useEffect(() => {
let saved = localStorage.getItem('status'); let saved = localStorage.getItem('status');
if (saved) { if (saved) {
@@ -203,18 +237,57 @@ const PersonalSetting = () => {
} }
}; };
const handleRegisterPasskey = async () => { const startPasskeyManagementVerification = async (apiCall, options = {}) => {
if (!passkeySupported || !window.PublicKeyCredential) { const methods = await checkPasskeyVerificationMethods();
const requiredMethod = methods.has2FA
? '2fa'
: methods.hasPasskey
? 'passkey'
: null;
if (!requiredMethod) {
showError(t('您需要先启用两步验证或 Passkey 才能执行此操作'));
return;
}
if (requiredMethod === 'passkey' && !methods.passkeySupported) {
showInfo(t('当前设备不支持 Passkey')); showInfo(t('当前设备不支持 Passkey'));
return; return;
} }
setPasskeyRequiredVerificationMethod(requiredMethod);
await startPasskeyVerification(apiCall, {
preferredMethod: requiredMethod,
title: t('安全验证'),
...options,
});
};
const startPasskeyRegistration = async () => {
const methods = await checkPasskeyVerificationMethods();
if (!methods.has2FA) {
try {
await registerPasskey();
} catch (error) {
showError(error.message || t('Passkey 注册失败,请重试'));
}
return;
}
setPasskeyRequiredVerificationMethod('2fa');
await startPasskeyVerification(registerPasskey, {
preferredMethod: '2fa',
title: t('安全验证'),
});
};
const registerPasskey = async () => {
setPasskeyRegisterLoading(true); setPasskeyRegisterLoading(true);
try { try {
const beginRes = await API.post('/api/user/passkey/register/begin'); const beginRes = await API.post('/api/user/passkey/register/begin');
const { success, message, data } = beginRes.data; const { success, message, data } = beginRes.data;
if (!success) { if (!success) {
showError(message || t('无法发起 Passkey 注册')); throw new Error(message || t('无法发起 Passkey 注册'));
return;
} }
const publicKey = prepareCredentialCreationOptions( const publicKey = prepareCredentialCreationOptions(
@@ -223,49 +296,69 @@ const PersonalSetting = () => {
const credential = await navigator.credentials.create({ publicKey }); const credential = await navigator.credentials.create({ publicKey });
const payload = buildRegistrationResult(credential); const payload = buildRegistrationResult(credential);
if (!payload) { if (!payload) {
showError(t('Passkey 注册失败,请重试')); throw new Error(t('Passkey 注册失败,请重试'));
return;
} }
const finishRes = await API.post( const finishRes = await API.post(
'/api/user/passkey/register/finish', '/api/user/passkey/register/finish',
payload, payload,
); );
if (finishRes.data.success) { if (!finishRes.data.success) {
showSuccess(t('Passkey 注册成功')); throw new Error(
await loadPasskeyStatus(); finishRes.data.message || t('Passkey 注册失败,请重试'),
} else { );
showError(finishRes.data.message || t('Passkey 注册失败,请重试'));
} }
showSuccess(t('Passkey 注册成功'));
await loadPasskeyStatus();
return finishRes.data;
} catch (error) { } catch (error) {
if (error?.name === 'AbortError') { if (error?.name === 'AbortError') {
showInfo(t('已取消 Passkey 注册')); showInfo(t('已取消 Passkey 注册'));
} else { return { cancelled: true };
showError(t('Passkey 注册失败,请重试'));
} }
throw new Error(error?.message || t('Passkey 注册失败,请重试'));
} finally { } finally {
setPasskeyRegisterLoading(false); setPasskeyRegisterLoading(false);
} }
}; };
const handleRemovePasskey = async () => { const handleRegisterPasskey = async () => {
if (!passkeySupported || !window.PublicKeyCredential) {
showInfo(t('当前设备不支持 Passkey'));
return;
}
await startPasskeyRegistration();
};
const removePasskey = async () => {
setPasskeyDeleteLoading(true); setPasskeyDeleteLoading(true);
try { try {
const res = await API.delete('/api/user/passkey'); const res = await API.delete('/api/user/passkey');
const { success, message } = res.data; const { success, message } = res.data;
if (success) { if (!success) {
showSuccess(t('Passkey 已解绑')); throw new Error(message || t('操作失败,请重试'));
await loadPasskeyStatus();
} else {
showError(message || t('操作失败,请重试'));
} }
showSuccess(t('Passkey 已解绑'));
await loadPasskeyStatus();
return res.data;
} catch (error) { } catch (error) {
showError(t('操作失败,请重试')); throw new Error(error?.message || t('操作失败,请重试'));
} finally { } finally {
setPasskeyDeleteLoading(false); setPasskeyDeleteLoading(false);
} }
}; };
const handleRemovePasskey = async () => {
await startPasskeyManagementVerification(removePasskey);
};
const handlePasskeyVerificationCancel = () => {
setPasskeyRequiredVerificationMethod(null);
cancelPasskeyVerification();
};
const getUserData = async () => { const getUserData = async () => {
let res = await API.get(`/api/user/self`); let res = await API.get(`/api/user/self`);
const { success, message, data } = res.data; const { success, message, data } = res.data;
@@ -556,6 +649,18 @@ const PersonalSetting = () => {
turnstileSiteKey={turnstileSiteKey} turnstileSiteKey={turnstileSiteKey}
setTurnstileToken={setTurnstileToken} setTurnstileToken={setTurnstileToken}
/> />
<SecureVerificationModal
visible={isPasskeyVerificationModalVisible}
verificationMethods={visiblePasskeyVerificationMethods}
verificationState={passkeyVerificationState}
onVerify={executePasskeyVerification}
onCancel={handlePasskeyVerificationCancel}
onCodeChange={setPasskeyVerificationCode}
onMethodSwitch={switchPasskeyVerificationMethod}
title={passkeyVerificationState.title}
description={passkeyVerificationState.description}
/>
</div> </div>
); );
}; };
@@ -29,6 +29,7 @@ import {
Collapse, Collapse,
} from '@douyinfe/semi-ui'; } from '@douyinfe/semi-ui';
import { API, showError } from '../../../../helpers'; import { API, showError } from '../../../../helpers';
import { MOBILE_BREAKPOINT } from '../../../../hooks/common/useIsMobile';
const { Text } = Typography; const { Text } = Typography;
@@ -98,10 +99,12 @@ const resolveRateLimitWindows = (data) => {
} }
if (!fiveHourWindow) { if (!fiveHourWindow) {
fiveHourWindow = windows.find((windowData) => windowData !== weeklyWindow) ?? null; fiveHourWindow =
windows.find((windowData) => windowData !== weeklyWindow) ?? null;
} }
if (!weeklyWindow) { if (!weeklyWindow) {
weeklyWindow = windows.find((windowData) => windowData !== fiveHourWindow) ?? null; weeklyWindow =
windows.find((windowData) => windowData !== fiveHourWindow) ?? null;
} }
return { fiveHourWindow, weeklyWindow }; return { fiveHourWindow, weeklyWindow };
@@ -135,6 +138,40 @@ const getDisplayText = (value) => {
return String(value).trim(); return String(value).trim();
}; };
const isMobileViewport = () =>
typeof window !== 'undefined' && window.innerWidth < MOBILE_BREAKPOINT;
const getCodexUsageModalLayout = () => {
if (isMobileViewport()) {
return {
width: 'calc(100vw - 16px)',
style: {
top: 8,
maxWidth: 'calc(100vw - 16px)',
margin: '0 auto',
},
bodyStyle: {
maxHeight: 'calc(100vh - 148px)',
overflowY: 'auto',
padding: '16px 16px 12px',
},
};
}
return {
width: 900,
style: {
top: 24,
maxWidth: 'min(900px, 92vw)',
},
bodyStyle: {
maxHeight: 'calc(100vh - 172px)',
overflowY: 'auto',
padding: '20px 24px 16px',
},
};
};
const formatAccountTypeLabel = (value, t) => { const formatAccountTypeLabel = (value, t) => {
const tt = typeof t === 'function' ? t : (v) => v; const tt = typeof t === 'function' ? t : (v) => v;
const normalized = normalizePlanType(value); const normalized = normalizePlanType(value);
@@ -224,7 +261,7 @@ const RateLimitWindowCard = ({ t, title, windowData }) => {
return ( return (
<div className='rounded-lg border border-semi-color-border bg-semi-color-bg-0 p-3'> <div className='rounded-lg border border-semi-color-border bg-semi-color-bg-0 p-3'>
<div className='flex items-center justify-between gap-2'> <div className='flex flex-wrap items-start justify-between gap-x-3 gap-y-1'>
<div className='font-medium'>{title}</div> <div className='font-medium'>{title}</div>
<Text type='tertiary' size='small'> <Text type='tertiary' size='small'>
{tt('重置时间:')} {tt('重置时间:')}
@@ -262,12 +299,86 @@ const RateLimitWindowCard = ({ t, title, windowData }) => {
); );
}; };
const RateLimitWindowGrid = ({ t, fiveHourWindow, weeklyWindow }) => {
const tt = typeof t === 'function' ? t : (v) => v;
return (
<div className='grid grid-cols-1 gap-3 md:grid-cols-2'>
<RateLimitWindowCard
t={tt}
title={tt('5小时窗口')}
windowData={fiveHourWindow}
/>
<RateLimitWindowCard
t={tt}
title={tt('每周窗口')}
windowData={weeklyWindow}
/>
</div>
);
};
const RateLimitGroupSection = ({
t,
title,
description,
rateLimitSource,
statusTag,
meteredFeature,
}) => {
const tt = typeof t === 'function' ? t : (v) => v;
const { fiveHourWindow, weeklyWindow } =
resolveRateLimitWindows(rateLimitSource);
const featureText = getDisplayText(meteredFeature);
return (
<section className='space-y-3'>
<div className='flex flex-wrap items-start justify-between gap-3'>
<div className='min-w-0 space-y-2'>
<div className='flex flex-wrap items-center gap-2'>
<div className='text-sm font-semibold text-semi-color-text-0'>
{title}
</div>
{statusTag}
</div>
{(description || featureText) && (
<div className='flex flex-wrap items-center gap-2 text-xs text-semi-color-text-2'>
{description ? <span>{description}</span> : null}
{featureText ? (
<div className='inline-flex max-w-full items-center gap-2 rounded-full bg-semi-color-fill-0 px-2 py-1'>
<span className='text-[11px] text-semi-color-text-2'>
metered_feature
</span>
<span className='min-w-0 break-all font-mono text-xs text-semi-color-text-0'>
{featureText}
</span>
</div>
) : null}
</div>
)}
</div>
</div>
<RateLimitWindowGrid
t={tt}
fiveHourWindow={fiveHourWindow}
weeklyWindow={weeklyWindow}
/>
</section>
);
};
const CodexUsageView = ({ t, record, payload, onCopy, onRefresh }) => { const CodexUsageView = ({ t, record, payload, onCopy, onRefresh }) => {
const tt = typeof t === 'function' ? t : (v) => v; const tt = typeof t === 'function' ? t : (v) => v;
const [showRawJson, setShowRawJson] = useState(false); const [showRawJson, setShowRawJson] = useState(false);
const data = payload?.data ?? null; const data = payload?.data ?? null;
const rateLimit = data?.rate_limit ?? {}; const rateLimit = data?.rate_limit ?? {};
const { fiveHourWindow, weeklyWindow } = resolveRateLimitWindows(data); const additionalRateLimits = Array.isArray(data?.additional_rate_limits)
? data.additional_rate_limits.filter(
(item) =>
item && typeof item === 'object' && Object.keys(item).length > 0,
)
: [];
const upstreamStatus = payload?.upstream_status; const upstreamStatus = payload?.upstream_status;
const accountType = data?.plan_type ?? rateLimit?.plan_type; const accountType = data?.plan_type ?? rateLimit?.plan_type;
const accountTypeLabel = formatAccountTypeLabel(accountType, tt); const accountTypeLabel = formatAccountTypeLabel(accountType, tt);
@@ -277,7 +388,9 @@ const CodexUsageView = ({ t, record, payload, onCopy, onRefresh }) => {
const email = data?.email; const email = data?.email;
const accountId = data?.account_id; const accountId = data?.account_id;
const errorMessage = const errorMessage =
payload?.success === false ? getDisplayText(payload?.message) || tt('获取用量失败') : ''; payload?.success === false
? getDisplayText(payload?.message) || tt('获取用量失败')
: '';
const rawText = const rawText =
typeof data === 'string' ? data : JSON.stringify(data ?? payload, null, 2); typeof data === 'string' ? data : JSON.stringify(data ?? payload, null, 2);
@@ -313,7 +426,12 @@ const CodexUsageView = ({ t, record, payload, onCopy, onRefresh }) => {
</Tag> </Tag>
</div> </div>
</div> </div>
<Button size='small' type='tertiary' theme='outline' onClick={onRefresh}> <Button
size='small'
type='tertiary'
theme='outline'
onClick={onRefresh}
>
{tt('刷新')} {tt('刷新')}
</Button> </Button>
</div> </div>
@@ -355,22 +473,61 @@ const CodexUsageView = ({ t, record, payload, onCopy, onRefresh }) => {
{tt('额度窗口')} {tt('额度窗口')}
</div> </div>
<Text type='tertiary' size='small'> <Text type='tertiary' size='small'>
{tt('用于观察当前帐号在 Codex 上游的限额使用情况')} {tt(
'用于观察当前帐号在 Codex 上游的基础限额与附加计费能力使用情况',
)}
</Text> </Text>
</div> </div>
</div> </div>
<div className='grid grid-cols-1 gap-3 md:grid-cols-2'> <div className='space-y-5'>
<RateLimitWindowCard <RateLimitGroupSection
t={tt} t={tt}
title={tt('5小时窗口')} title={tt('基础额度')}
windowData={fiveHourWindow} description={tt('当前帐号的基础额度窗口')}
/> rateLimitSource={data}
<RateLimitWindowCard statusTag={statusTag}
t={tt}
title={tt('每周窗口')}
windowData={weeklyWindow}
/> />
{additionalRateLimits.length > 0 ? (
<div className='space-y-4 border-t border-semi-color-border pt-4'>
<div>
<div className='text-sm font-semibold text-semi-color-text-0'>
{tt('附加额度')}
</div>
<Text type='tertiary' size='small'>
{tt('按模型或能力拆分的附加计费能力窗口')}
</Text>
</div>
<div className='space-y-4'>
{additionalRateLimits.map((item, index) => {
const limitName =
getDisplayText(item?.limit_name) ||
getDisplayText(item?.metered_feature) ||
`${tt('附加额度')} ${index + 1}`;
return (
<div
key={`${limitName}-${getDisplayText(item?.metered_feature)}-${index}`}
className={
index > 0 ? 'border-t border-semi-color-border pt-4' : ''
}
>
<RateLimitGroupSection
t={tt}
title={limitName}
description={tt('附加计费能力')}
rateLimitSource={item}
statusTag={resolveUsageStatusTag(tt, item?.rate_limit)}
meteredFeature={item?.metered_feature}
/>
</div>
);
})}
</div>
</div>
) : null}
</div> </div>
<Collapse <Collapse
@@ -489,12 +646,14 @@ const CodexUsageLoader = ({ t, record, initialPayload, onCopy }) => {
export const openCodexUsageModal = ({ t, record, payload, onCopy }) => { export const openCodexUsageModal = ({ t, record, payload, onCopy }) => {
const tt = typeof t === 'function' ? t : (v) => v; const tt = typeof t === 'function' ? t : (v) => v;
const layout = getCodexUsageModalLayout();
Modal.info({ Modal.info({
title: tt('Codex 帐号与用量'), title: tt('Codex 帐号与用量'),
centered: true, centered: false,
width: 900, width: layout.width,
style: { maxWidth: '95vw' }, style: layout.style,
bodyStyle: layout.bodyStyle,
content: ( content: (
<CodexUsageLoader <CodexUsageLoader
t={tt} t={tt}
@@ -208,6 +208,7 @@ const EditChannelModal = (props) => {
allow_safety_identifier: false, allow_safety_identifier: false,
allow_include_obfuscation: false, allow_include_obfuscation: false,
allow_inference_geo: false, allow_inference_geo: false,
allow_speed: false,
claude_beta_query: false, claude_beta_query: false,
upstream_model_update_check_enabled: false, upstream_model_update_check_enabled: false,
upstream_model_update_auto_sync_enabled: false, upstream_model_update_auto_sync_enabled: false,
@@ -890,6 +891,7 @@ const EditChannelModal = (props) => {
parsedSettings.allow_include_obfuscation || false; parsedSettings.allow_include_obfuscation || false;
data.allow_inference_geo = data.allow_inference_geo =
parsedSettings.allow_inference_geo || false; parsedSettings.allow_inference_geo || false;
data.allow_speed = parsedSettings.allow_speed || false;
data.claude_beta_query = parsedSettings.claude_beta_query || false; data.claude_beta_query = parsedSettings.claude_beta_query || false;
data.upstream_model_update_check_enabled = data.upstream_model_update_check_enabled =
parsedSettings.upstream_model_update_check_enabled === true; parsedSettings.upstream_model_update_check_enabled === true;
@@ -919,6 +921,7 @@ const EditChannelModal = (props) => {
data.allow_safety_identifier = false; data.allow_safety_identifier = false;
data.allow_include_obfuscation = false; data.allow_include_obfuscation = false;
data.allow_inference_geo = false; data.allow_inference_geo = false;
data.allow_speed = false;
data.claude_beta_query = false; data.claude_beta_query = false;
data.upstream_model_update_check_enabled = false; data.upstream_model_update_check_enabled = false;
data.upstream_model_update_auto_sync_enabled = false; data.upstream_model_update_auto_sync_enabled = false;
@@ -936,6 +939,7 @@ const EditChannelModal = (props) => {
data.allow_safety_identifier = false; data.allow_safety_identifier = false;
data.allow_include_obfuscation = false; data.allow_include_obfuscation = false;
data.allow_inference_geo = false; data.allow_inference_geo = false;
data.allow_speed = false;
data.claude_beta_query = false; data.claude_beta_query = false;
data.upstream_model_update_check_enabled = false; data.upstream_model_update_check_enabled = false;
data.upstream_model_update_auto_sync_enabled = false; data.upstream_model_update_auto_sync_enabled = false;
@@ -1776,6 +1780,7 @@ const EditChannelModal = (props) => {
} }
if (localInputs.type === 14) { if (localInputs.type === 14) {
settings.allow_inference_geo = localInputs.allow_inference_geo === true; settings.allow_inference_geo = localInputs.allow_inference_geo === true;
settings.allow_speed = localInputs.allow_speed === true;
settings.claude_beta_query = localInputs.claude_beta_query === true; settings.claude_beta_query = localInputs.claude_beta_query === true;
} }
} }
@@ -1823,6 +1828,7 @@ const EditChannelModal = (props) => {
delete localInputs.allow_safety_identifier; delete localInputs.allow_safety_identifier;
delete localInputs.allow_include_obfuscation; delete localInputs.allow_include_obfuscation;
delete localInputs.allow_inference_geo; delete localInputs.allow_inference_geo;
delete localInputs.allow_speed;
delete localInputs.claude_beta_query; delete localInputs.claude_beta_query;
delete localInputs.upstream_model_update_check_enabled; delete localInputs.upstream_model_update_check_enabled;
delete localInputs.upstream_model_update_auto_sync_enabled; delete localInputs.upstream_model_update_auto_sync_enabled;
@@ -2480,6 +2486,7 @@ const EditChannelModal = (props) => {
</div> </div>
<Form.Switch field='allow_service_tier' label={t('允许 service_tier 透传')} checkedText={t('开')} uncheckedText={t('关')} onChange={(value) => handleChannelOtherSettingsChange('allow_service_tier', value)} extraText={t('service_tier 字段用于指定服务层级,允许透传可能导致实际计费高于预期。默认关闭以避免额外费用')} /> <Form.Switch field='allow_service_tier' label={t('允许 service_tier 透传')} checkedText={t('开')} uncheckedText={t('关')} onChange={(value) => handleChannelOtherSettingsChange('allow_service_tier', value)} extraText={t('service_tier 字段用于指定服务层级,允许透传可能导致实际计费高于预期。默认关闭以避免额外费用')} />
<Form.Switch field='allow_inference_geo' label={t('允许 inference_geo 透传')} checkedText={t('开')} uncheckedText={t('关')} onChange={(value) => handleChannelOtherSettingsChange('allow_inference_geo', value)} extraText={t('inference_geo 字段用于控制 Claude 数据驻留推理区域。默认关闭以避免未经授权透传地域信息')} /> <Form.Switch field='allow_inference_geo' label={t('允许 inference_geo 透传')} checkedText={t('开')} uncheckedText={t('关')} onChange={(value) => handleChannelOtherSettingsChange('allow_inference_geo', value)} extraText={t('inference_geo 字段用于控制 Claude 数据驻留推理区域。默认关闭以避免未经授权透传地域信息')} />
<Form.Switch field='allow_speed' label={t('允许 speed 透传')} checkedText={t('开')} uncheckedText={t('关')} onChange={(value) => handleChannelOtherSettingsChange('allow_speed', value)} extraText={t('speed 字段用于控制 Claude 推理速度模式。默认关闭以避免意外切换到 fast 模式')} />
</> </>
)} )}
</div> </div>
@@ -30,6 +30,7 @@ import {
Banner, Banner,
} from '@douyinfe/semi-ui'; } from '@douyinfe/semi-ui';
import { IconSearch, IconInfoCircle } from '@douyinfe/semi-icons'; import { IconSearch, IconInfoCircle } from '@douyinfe/semi-icons';
import { Settings } from 'lucide-react';
import { copy, showError, showInfo, showSuccess } from '../../../../helpers'; import { copy, showError, showInfo, showSuccess } from '../../../../helpers';
import { MODEL_TABLE_PAGE_SIZE } from '../../../../constants'; import { MODEL_TABLE_PAGE_SIZE } from '../../../../constants';
@@ -168,17 +169,43 @@ const ModelTestModal = ({
} }
return ( return (
<div className='flex items-center gap-2'> <div className='flex flex-col gap-1'>
<Tag color={testResult.success ? 'green' : 'red'} shape='circle'> <div className='flex items-center gap-2'>
{testResult.success ? t('成功') : t('失败')} <Tag color={testResult.success ? 'green' : 'red'} shape='circle'>
</Tag> {testResult.success ? t('成功') : t('失败')}
{testResult.success && ( </Tag>
<Typography.Text type='tertiary'> {testResult.success && (
{t('请求时长: ${time}s').replace( <Typography.Text type='tertiary'>
'${time}', {t('请求时长: ${time}s').replace(
testResult.time.toFixed(2), '${time}',
testResult.time.toFixed(2),
)}
</Typography.Text>
)}
</div>
{!testResult.success && testResult.message && (
<div className='flex flex-col gap-1'>
<Typography.Text
type='danger'
size='small'
className='break-all'
style={{ maxWidth: '400px', fontSize: '12px' }}
>
{testResult.message}
</Typography.Text>
{testResult.errorCode === 'model_price_error' && (
<Button
size='small'
theme='light'
type='warning'
icon={<Settings size={12} />}
onClick={() => window.open('/console/setting?tab=ratio', '_blank')}
style={{ width: 'fit-content' }}
>
{t('前往设置')}
</Button>
)} )}
</Typography.Text> </div>
)} )}
</div> </div>
); );
@@ -360,7 +360,7 @@ const MultiKeyManageModal = ({ visible, onCancel, channel, onRefresh }) => {
{ {
title: t('索引'), title: t('索引'),
dataIndex: 'index', dataIndex: 'index',
render: (text) => `#${text}`, render: (text) => `#${Number(text) + 1}`,
}, },
// { // {
// title: t(''), // title: t(''),
@@ -25,8 +25,12 @@ import {
showError, showError,
showSuccess, showSuccess,
renderQuota, renderQuota,
renderQuotaWithPrompt, getCurrencyConfig,
} from '../../../../helpers'; } from '../../../../helpers';
import {
quotaToDisplayAmount,
displayAmountToQuota,
} from '../../../../helpers/quota';
import { useIsMobile } from '../../../../hooks/common/useIsMobile'; import { useIsMobile } from '../../../../hooks/common/useIsMobile';
import { import {
Button, Button,
@@ -41,6 +45,7 @@ import {
Avatar, Avatar,
Row, Row,
Col, Col,
InputNumber,
} from '@douyinfe/semi-ui'; } from '@douyinfe/semi-ui';
import { import {
IconCreditCard, IconCreditCard,
@@ -57,10 +62,12 @@ const EditRedemptionModal = (props) => {
const [loading, setLoading] = useState(isEdit); const [loading, setLoading] = useState(isEdit);
const isMobile = useIsMobile(); const isMobile = useIsMobile();
const formApiRef = useRef(null); const formApiRef = useRef(null);
const [showQuotaInput, setShowQuotaInput] = useState(false);
const getInitValues = () => ({ const getInitValues = () => ({
name: '', name: '',
quota: 100000, quota: 100000,
amount: Number(quotaToDisplayAmount(100000).toFixed(6)),
count: 1, count: 1,
expired_time: null, expired_time: null,
}); });
@@ -79,6 +86,7 @@ const EditRedemptionModal = (props) => {
} else { } else {
data.expired_time = new Date(data.expired_time * 1000); data.expired_time = new Date(data.expired_time * 1000);
} }
data.amount = Number(quotaToDisplayAmount(data.quota || 0).toFixed(6));
formApiRef.current?.setValues({ ...getInitValues(), ...data }); formApiRef.current?.setValues({ ...getInitValues(), ...data });
} else { } else {
showError(message); showError(message);
@@ -104,7 +112,12 @@ const EditRedemptionModal = (props) => {
setLoading(true); setLoading(true);
let localInputs = { ...values }; let localInputs = { ...values };
localInputs.count = parseInt(localInputs.count) || 0; localInputs.count = parseInt(localInputs.count) || 0;
localInputs.quota = parseInt(localInputs.quota) || 0; localInputs.quota = displayAmountToQuota(localInputs.amount);
if (localInputs.quota <= 0) {
showError(t('请输入金额'));
setLoading(false);
return;
}
localInputs.name = name; localInputs.name = name;
if (!localInputs.expired_time) { if (!localInputs.expired_time) {
localInputs.expired_time = 0; localInputs.expired_time = 0;
@@ -285,37 +298,63 @@ const EditRedemptionModal = (props) => {
</div> </div>
<Row gutter={12}> <Row gutter={12}>
<Col span={12}> <Col span={24}>
<Form.AutoComplete <Form.InputNumber
field='quota' field='amount'
label={t('额')} label={t('额')}
placeholder={t('请输入额度')} prefix={getCurrencyConfig().symbol}
placeholder={t('输入金额')}
precision={6}
min={0}
step={0.000001}
style={{ width: '100%' }} style={{ width: '100%' }}
type='number' onChange={(val) => {
rules={[ const amount = val === '' || val == null ? 0 : val;
{ required: true, message: t('请输入额度') }, formApiRef.current?.setValue('amount', amount);
{ formApiRef.current?.setValue(
validator: (rule, v) => { 'quota',
const num = parseInt(v, 10); displayAmountToQuota(amount),
return num > 0 );
? Promise.resolve() }}
: Promise.reject(t('额度必须大于0'));
},
},
]}
extraText={renderQuotaWithPrompt(
Number(values.quota) || 0,
)}
data={[
{ value: 500000, label: '1$' },
{ value: 5000000, label: '10$' },
{ value: 25000000, label: '50$' },
{ value: 50000000, label: '100$' },
{ value: 250000000, label: '500$' },
{ value: 500000000, label: '1000$' },
]}
showClear showClear
/> />
<div
className='text-xs cursor-pointer mt-1'
style={{ color: 'var(--semi-color-text-2)' }}
onClick={() => setShowQuotaInput((v) => !v)}
>
{showQuotaInput
? `${t('收起原生额度输入')}`
: `${t('使用原生额度输入')}`}
</div>
<div style={{ display: showQuotaInput ? 'block' : 'none' }} className='mt-2'>
<Form.InputNumber
field='quota'
label={t('额度')}
placeholder={t('输入额度')}
rules={[
{ required: true, message: t('请输入额度') },
{
validator: (rule, v) => {
const num = parseInt(v, 10);
return num > 0
? Promise.resolve()
: Promise.reject(t('额度必须大于0'));
},
},
]}
onChange={(val) => {
const quota = val === '' || val == null ? 0 : val;
formApiRef.current?.setValue('quota', quota);
formApiRef.current?.setValue(
'amount',
Number(quotaToDisplayAmount(quota).toFixed(6)),
);
}}
style={{ width: '100%' }}
showClear
/>
</div>
</Col> </Col>
{!isEdit && ( {!isEdit && (
<Col span={12}> <Col span={12}>
@@ -536,6 +536,13 @@ export const getTokensColumns = ({
return <div>{renderTimestamp(text)}</div>; return <div>{renderTimestamp(text)}</div>;
}, },
}, },
{
title: t('最后使用时间'),
dataIndex: 'accessed_time',
render: (text, record, index) => {
return <div>{text ? renderTimestamp(text) : '-'}</div>;
},
},
{ {
title: t('过期时间'), title: t('过期时间'),
dataIndex: 'expired_time', dataIndex: 'expired_time',
@@ -24,10 +24,14 @@ import {
showSuccess, showSuccess,
timestamp2string, timestamp2string,
renderGroupOption, renderGroupOption,
renderQuotaWithPrompt, getCurrencyConfig,
getModelCategories, getModelCategories,
selectFilter, selectFilter,
} from '../../../../helpers'; } from '../../../../helpers';
import {
quotaToDisplayAmount,
displayAmountToQuota,
} from '../../../../helpers/quota';
import { useIsMobile } from '../../../../hooks/common/useIsMobile'; import { useIsMobile } from '../../../../hooks/common/useIsMobile';
import { import {
Button, Button,
@@ -41,6 +45,7 @@ import {
Form, Form,
Col, Col,
Row, Row,
InputNumber,
} from '@douyinfe/semi-ui'; } from '@douyinfe/semi-ui';
import { import {
IconCreditCard, IconCreditCard,
@@ -62,11 +67,13 @@ const EditTokenModal = (props) => {
const formApiRef = useRef(null); const formApiRef = useRef(null);
const [models, setModels] = useState([]); const [models, setModels] = useState([]);
const [groups, setGroups] = useState([]); const [groups, setGroups] = useState([]);
const [showQuotaInput, setShowQuotaInput] = useState(false);
const isEdit = props.editingToken.id !== undefined; const isEdit = props.editingToken.id !== undefined;
const getInitValues = () => ({ const getInitValues = () => ({
name: '', name: '',
remain_quota: 0, remain_quota: 0,
remain_amount: 0,
expired_time: -1, expired_time: -1,
unlimited_quota: true, unlimited_quota: true,
model_limits_enabled: false, model_limits_enabled: false,
@@ -162,6 +169,9 @@ const EditTokenModal = (props) => {
} else { } else {
data.model_limits = []; data.model_limits = [];
} }
data.remain_amount = Number(
quotaToDisplayAmount(data.remain_quota || 0).toFixed(6),
);
if (formApiRef.current) { if (formApiRef.current) {
formApiRef.current.setValues({ ...getInitValues(), ...data }); formApiRef.current.setValues({ ...getInitValues(), ...data });
} }
@@ -209,7 +219,14 @@ const EditTokenModal = (props) => {
setLoading(true); setLoading(true);
if (isEdit) { if (isEdit) {
let { tokenCount: _tc, ...localInputs } = values; let { tokenCount: _tc, ...localInputs } = values;
localInputs.remain_quota = parseInt(localInputs.remain_quota); localInputs.remain_quota = localInputs.unlimited_quota
? 0
: displayAmountToQuota(localInputs.remain_amount);
if (!localInputs.unlimited_quota && localInputs.remain_quota <= 0) {
showError(t('请输入金额'));
setLoading(false);
return;
}
if (localInputs.expired_time !== -1) { if (localInputs.expired_time !== -1) {
let time = Date.parse(localInputs.expired_time); let time = Date.parse(localInputs.expired_time);
if (isNaN(time)) { if (isNaN(time)) {
@@ -245,7 +262,14 @@ const EditTokenModal = (props) => {
} else { } else {
localInputs.name = baseName; localInputs.name = baseName;
} }
localInputs.remain_quota = parseInt(localInputs.remain_quota); localInputs.remain_quota = localInputs.unlimited_quota
? 0
: displayAmountToQuota(localInputs.remain_amount);
if (!localInputs.unlimited_quota && localInputs.remain_quota <= 0) {
showError(t('请输入金额'));
setLoading(false);
break;
}
if (localInputs.expired_time !== -1) { if (localInputs.expired_time !== -1) {
let time = Date.parse(localInputs.expired_time); let time = Date.parse(localInputs.expired_time);
@@ -497,28 +521,63 @@ const EditTokenModal = (props) => {
</div> </div>
<Row gutter={12}> <Row gutter={12}>
<Col span={24}> <Col span={24}>
<Form.AutoComplete <Form.InputNumber
field='remain_quota' field='remain_amount'
label={t('额')} label={t('额')}
placeholder={t('请输入额度')} prefix={getCurrencyConfig().symbol}
type='number' placeholder={t('输入金额')}
precision={6}
disabled={values.unlimited_quota} disabled={values.unlimited_quota}
extraText={renderQuotaWithPrompt(values.remain_quota)} min={0}
rules={ step={0.000001}
values.unlimited_quota onChange={(val) => {
? [] const amount = val === '' || val == null ? 0 : val;
: [{ required: true, message: t('请输入额度') }] formApiRef.current?.setValue('remain_amount', amount);
} formApiRef.current?.setValue(
data={[ 'remain_quota',
{ value: 500000, label: '1$' }, displayAmountToQuota(amount),
{ value: 5000000, label: '10$' }, );
{ value: 25000000, label: '50$' }, }}
{ value: 50000000, label: '100$' }, style={{ width: '100%' }}
{ value: 250000000, label: '500$' }, showClear
{ value: 500000000, label: '1000$' },
]}
/> />
</Col> </Col>
<Col span={24}>
<div
className='text-xs cursor-pointer mt-1'
style={{ color: 'var(--semi-color-text-2)' }}
onClick={() => setShowQuotaInput((v) => !v)}
>
{showQuotaInput
? `${t('收起原生额度输入')}`
: `${t('使用原生额度输入')}`}
</div>
<div style={{ display: showQuotaInput ? 'block' : 'none' }} className='mt-2'>
<Form.InputNumber
field='remain_quota'
label={t('额度')}
placeholder={t('输入额度')}
disabled={values.unlimited_quota}
min={0}
step={500000}
rules={
values.unlimited_quota
? []
: [{ required: true, message: t('请输入额度') }]
}
onChange={(val) => {
const quota = val === '' || val == null ? 0 : val;
formApiRef.current?.setValue('remain_quota', quota);
formApiRef.current?.setValue(
'remain_amount',
Number(quotaToDisplayAmount(quota).toFixed(6)),
);
}}
style={{ width: '100%' }}
showClear
/>
</div>
</Col>
<Col span={24}> <Col span={24}>
<Form.Switch <Form.Switch
field='unlimited_quota' field='unlimited_quota'
@@ -845,7 +845,12 @@ export const getLogsColumns = ({
), ),
dataIndex: 'ip', dataIndex: 'ip',
render: (text, record, index) => { render: (text, record, index) => {
return (record.type === 2 || record.type === 5) && text ? ( const showIp =
(record.type === 2 ||
record.type === 5 ||
(isAdminUser && record.type === 1)) &&
text;
return showIp ? (
<Tooltip content={text}> <Tooltip content={text}>
<span> <span>
<Tag <Tag
@@ -24,7 +24,6 @@ import {
showError, showError,
showSuccess, showSuccess,
renderQuota, renderQuota,
renderQuotaWithPrompt,
getCurrencyConfig, getCurrencyConfig,
} from '../../../../helpers'; } from '../../../../helpers';
import { import {
@@ -46,6 +45,8 @@ import {
Row, Row,
Col, Col,
InputNumber, InputNumber,
RadioGroup,
Radio,
} from '@douyinfe/semi-ui'; } from '@douyinfe/semi-ui';
import { import {
IconUser, IconUser,
@@ -53,7 +54,7 @@ import {
IconClose, IconClose,
IconLink, IconLink,
IconUserGroup, IconUserGroup,
IconPlus, IconEdit,
} from '@douyinfe/semi-icons'; } from '@douyinfe/semi-icons';
import UserBindingManagementModal from './UserBindingManagementModal'; import UserBindingManagementModal from './UserBindingManagementModal';
@@ -63,13 +64,18 @@ const EditUserModal = (props) => {
const { t } = useTranslation(); const { t } = useTranslation();
const userId = props.editingUser.id; const userId = props.editingUser.id;
const [loading, setLoading] = useState(true); const [loading, setLoading] = useState(true);
const [addQuotaModalOpen, setIsModalOpen] = useState(false); const [adjustModalOpen, setAdjustModalOpen] = useState(false);
const [addQuotaLocal, setAddQuotaLocal] = useState(''); const [adjustQuotaLocal, setAdjustQuotaLocal] = useState('');
const [addAmountLocal, setAddAmountLocal] = useState(''); const [adjustAmountLocal, setAdjustAmountLocal] = useState('');
const [adjustMode, setAdjustMode] = useState('add');
const [adjustLoading, setAdjustLoading] = useState(false);
const isMobile = useIsMobile(); const isMobile = useIsMobile();
const [groupOptions, setGroupOptions] = useState([]); const [groupOptions, setGroupOptions] = useState([]);
const [bindingModalVisible, setBindingModalVisible] = useState(false); const [bindingModalVisible, setBindingModalVisible] = useState(false);
const formApiRef = useRef(null); const formApiRef = useRef(null);
const [showAdjustQuotaRaw, setShowAdjustQuotaRaw] = useState(false);
const [showQuotaInput, setShowQuotaInput] = useState(false);
const [inputs, setInputs] = useState(null);
const isEdit = Boolean(userId); const isEdit = Boolean(userId);
@@ -85,6 +91,7 @@ const EditUserModal = (props) => {
linux_do_id: '', linux_do_id: '',
email: '', email: '',
quota: 0, quota: 0,
quota_amount: 0,
group: 'default', group: 'default',
remark: '', remark: '',
}); });
@@ -107,13 +114,22 @@ const EditUserModal = (props) => {
const { success, message, data } = res.data; const { success, message, data } = res.data;
if (success) { if (success) {
data.password = ''; data.password = '';
formApiRef.current?.setValues({ ...getInitValues(), ...data }); data.quota_amount = Number(
quotaToDisplayAmount(data.quota || 0).toFixed(6),
);
setInputs({ ...getInitValues(), ...data });
} else { } else {
showError(message); showError(message);
} }
setLoading(false); setLoading(false);
}; };
useEffect(() => {
if (inputs && formApiRef.current) {
formApiRef.current.setValues(inputs);
}
}, [inputs]);
useEffect(() => { useEffect(() => {
loadUser(); loadUser();
if (userId) fetchGroups(); if (userId) fetchGroups();
@@ -132,8 +148,8 @@ const EditUserModal = (props) => {
const submit = async (values) => { const submit = async (values) => {
setLoading(true); setLoading(true);
let payload = { ...values }; let payload = { ...values };
if (typeof payload.quota === 'string') delete payload.quota;
payload.quota = parseInt(payload.quota) || 0; delete payload.quota_amount;
if (userId) { if (userId) {
payload.id = parseInt(userId); payload.id = parseInt(userId);
} }
@@ -150,11 +166,60 @@ const EditUserModal = (props) => {
setLoading(false); setLoading(false);
}; };
/* --------------------- quota helper -------------------- */ /* --------------------- atomic quota adjust -------------------- */
const addLocalQuota = () => { const adjustQuota = async () => {
const current = parseInt(formApiRef.current?.getValue('quota') || 0); const quotaVal = parseInt(adjustQuotaLocal) || 0;
const delta = parseInt(addQuotaLocal) || 0; if (quotaVal <= 0 && adjustMode !== 'override') return;
formApiRef.current?.setValue('quota', current + delta); if (adjustMode === 'override' && (adjustQuotaLocal === '' || adjustQuotaLocal == null)) return;
setAdjustLoading(true);
try {
const res = await API.post('/api/user/manage', {
id: parseInt(userId),
action: 'add_quota',
mode: adjustMode,
value: adjustMode === 'override' ? quotaVal : Math.abs(quotaVal),
});
const { success, message } = res.data;
if (success) {
showSuccess(t('调整额度成功'));
setAdjustModalOpen(false);
setAdjustQuotaLocal('');
setAdjustAmountLocal('');
const userRes = await API.get(`/api/user/${userId}`);
if (userRes.data.success) {
const data = userRes.data.data;
data.password = '';
data.quota_amount = Number(
quotaToDisplayAmount(data.quota || 0).toFixed(6),
);
setInputs({ ...getInitValues(), ...data });
}
props.refresh();
} else {
showError(message);
}
} catch (e) {
showError(e.message);
}
setAdjustLoading(false);
};
const getPreviewText = () => {
const current = formApiRef.current?.getValue('quota') || 0;
const val = parseInt(adjustQuotaLocal) || 0;
let result;
switch (adjustMode) {
case 'add':
result = current + Math.abs(val);
return `${t('当前额度')}${renderQuota(current)}+${renderQuota(Math.abs(val))} = ${renderQuota(result)}`;
case 'subtract':
result = current - Math.abs(val);
return `${t('当前额度')}${renderQuota(current)}-${renderQuota(Math.abs(val))} = ${renderQuota(result)}`;
case 'override':
return `${t('当前额度')}${renderQuota(current)}${renderQuota(val)}`;
default:
return '';
}
}; };
/* --------------------------- UI --------------------------- */ /* --------------------------- UI --------------------------- */
@@ -305,24 +370,47 @@ const EditUserModal = (props) => {
<Col span={10}> <Col span={10}>
<Form.InputNumber <Form.InputNumber
field='quota' field='quota_amount'
label={t('剩余额度')} label={t('金额')}
placeholder={t('请输入新的剩余额度')} prefix={getCurrencyConfig().symbol}
step={500000} precision={6}
extraText={renderQuotaWithPrompt(values.quota || 0)} step={0.000001}
rules={[{ required: true, message: t('请输入额度') }]}
style={{ width: '100%' }} style={{ width: '100%' }}
readonly
/> />
</Col> </Col>
<Col span={14}> <Col span={14}>
<Form.Slot label={t('添加额度')}> <Form.Slot label={t('调整额度')}>
<Button <Button
icon={<IconPlus />} icon={<IconEdit />}
onClick={() => setIsModalOpen(true)} onClick={() => setAdjustModalOpen(true)}
/> >
{t('调整额度')}
</Button>
</Form.Slot> </Form.Slot>
</Col> </Col>
<Col span={24}>
<div
className='text-xs cursor-pointer'
style={{ color: 'var(--semi-color-text-2)' }}
onClick={() => setShowQuotaInput((v) => !v)}
>
{showQuotaInput
? `${t('收起原生额度输入')}`
: `${t('使用原生额度输入')}`}
</div>
<div style={{ display: showQuotaInput ? 'block' : 'none' }} className='mt-2'>
<Form.InputNumber
field='quota'
label={t('额度')}
placeholder={t('请输入额度')}
style={{ width: '100%' }}
readonly
/>
</div>
</Col>
</Row> </Row>
</Card> </Card>
)} )}
@@ -372,81 +460,102 @@ const EditUserModal = (props) => {
formApiRef={formApiRef} formApiRef={formApiRef}
/> />
{/* 添加额度模态框 */} {/* 调整额度模态框 */}
<Modal <Modal
centered centered
visible={addQuotaModalOpen} visible={adjustModalOpen}
onOk={() => { onOk={adjustQuota}
addLocalQuota();
setIsModalOpen(false);
setAddQuotaLocal('');
setAddAmountLocal('');
}}
onCancel={() => { onCancel={() => {
setIsModalOpen(false); setAdjustModalOpen(false);
setAdjustQuotaLocal('');
setAdjustAmountLocal('');
setAdjustMode('add');
}} }}
confirmLoading={adjustLoading}
closable={null} closable={null}
title={ title={
<div className='flex items-center'> <div className='flex items-center'>
<IconPlus className='mr-2' /> <IconEdit className='mr-2' />
{t('添加额度')} {t('调整额度')}
</div> </div>
} }
> >
<div className='mb-4'> <div className='mb-4'>
{(() => { <Text type='secondary' className='block mb-2'>
const current = formApiRef.current?.getValue('quota') || 0; {getPreviewText()}
return ( </Text>
<Text type='secondary' className='block mb-2'>
{`${t('新额度:')}${renderQuota(current)} + ${renderQuota(addQuotaLocal)} = ${renderQuota(current + parseInt(addQuotaLocal || 0))}`}
</Text>
);
})()}
</div> </div>
{getCurrencyConfig().type !== 'TOKENS' && ( <div className='mb-3'>
<div className='mb-3'> <div className='mb-1'>
<div className='mb-1'> <Text size='small'>{t('操作')}</Text>
<Text size='small'>{t('金额')}</Text>
<Text size='small' type='tertiary'>
{' '}
({t('仅用于换算,实际保存的是额度')})
</Text>
</div>
<InputNumber
prefix={getCurrencyConfig().symbol}
placeholder={t('输入金额')}
value={addAmountLocal}
precision={2}
onChange={(val) => {
setAddAmountLocal(val);
setAddQuotaLocal(
val != null && val !== ''
? displayAmountToQuota(Math.abs(val)) * Math.sign(val)
: '',
);
}}
style={{ width: '100%' }}
showClear
/>
</div> </div>
)} <RadioGroup
<div> type='button'
value={adjustMode}
onChange={(e) => {
setAdjustMode(e.target.value);
setAdjustQuotaLocal('');
setAdjustAmountLocal('');
}}
style={{ width: '100%' }}
>
<Radio value='add'>{t('添加')}</Radio>
<Radio value='subtract'>{t('减少')}</Radio>
<Radio value='override'>{t('覆盖')}</Radio>
</RadioGroup>
</div>
<div className='mb-3'>
<div className='mb-1'>
<Text size='small'>{t('金额')}</Text>
</div>
<InputNumber
prefix={getCurrencyConfig().symbol}
placeholder={t('输入金额')}
value={adjustAmountLocal}
precision={6}
min={adjustMode === 'override' ? undefined : 0}
step={0.000001}
onChange={(val) => {
const amount = val === '' || val == null ? '' : val;
setAdjustAmountLocal(amount);
setAdjustQuotaLocal(
amount === ''
? ''
: adjustMode === 'override'
? displayAmountToQuota(amount)
: displayAmountToQuota(Math.abs(amount)),
);
}}
style={{ width: '100%' }}
showClear
/>
</div>
<div
className='text-xs cursor-pointer mt-2'
style={{ color: 'var(--semi-color-text-2)' }}
onClick={() => setShowAdjustQuotaRaw((v) => !v)}
>
{showAdjustQuotaRaw
? `${t('收起原生额度输入')}`
: `${t('使用原生额度输入')}`}
</div>
<div style={{ display: showAdjustQuotaRaw ? 'block' : 'none' }} className='mt-2'>
<div className='mb-1'> <div className='mb-1'>
<Text size='small'>{t('额度')}</Text> <Text size='small'>{t('额度')}</Text>
</div> </div>
<InputNumber <InputNumber
placeholder={t('输入额度')} placeholder={t('输入额度')}
value={addQuotaLocal} value={adjustQuotaLocal}
min={adjustMode === 'override' ? undefined : 0}
onChange={(val) => { onChange={(val) => {
setAddQuotaLocal(val); const quota = val === '' || val == null ? '' : val;
setAddAmountLocal( setAdjustQuotaLocal(quota);
val != null && val !== '' setAdjustAmountLocal(
? Number( quota === ''
( ? ''
quotaToDisplayAmount(Math.abs(val)) * Math.sign(val) : adjustMode === 'override'
).toFixed(2), ? Number(quotaToDisplayAmount(quota).toFixed(6))
) : Number(quotaToDisplayAmount(Math.abs(quota)).toFixed(6)),
: '',
); );
}} }}
style={{ width: '100%' }} style={{ width: '100%' }}
+60 -57
View File
@@ -21,7 +21,6 @@ import React, { useEffect, useRef, useState } from 'react';
import { import {
Avatar, Avatar,
Typography, Typography,
Tag,
Card, Card,
Button, Button,
Banner, Banner,
@@ -32,6 +31,7 @@ import {
Col, Col,
Spin, Spin,
Tooltip, Tooltip,
Tag,
Tabs, Tabs,
TabPane, TabPane,
} from '@douyinfe/semi-ui'; } from '@douyinfe/semi-ui';
@@ -88,8 +88,7 @@ const RechargeCard = ({
topupInfo, topupInfo,
onOpenHistory, onOpenHistory,
enableWaffoTopUp, enableWaffoTopUp,
waffoTopUp, enableWaffoPancakeTopUp,
waffoPayMethods,
subscriptionLoading = false, subscriptionLoading = false,
subscriptionPlans = [], subscriptionPlans = [],
billingPreference, billingPreference,
@@ -105,6 +104,7 @@ const RechargeCard = ({
const [activeTab, setActiveTab] = useState('topup'); const [activeTab, setActiveTab] = useState('topup');
const shouldShowSubscription = const shouldShowSubscription =
!subscriptionLoading && subscriptionPlans.length > 0; !subscriptionLoading && subscriptionPlans.length > 0;
const regularPayMethods = payMethods || [];
useEffect(() => { useEffect(() => {
if (initialTabSetRef.current) return; if (initialTabSetRef.current) return;
@@ -227,19 +227,31 @@ const RechargeCard = ({
<div className='py-8 flex justify-center'> <div className='py-8 flex justify-center'>
<Spin size='large' /> <Spin size='large' />
</div> </div>
) : enableOnlineTopUp || enableStripeTopUp || enableCreemTopUp || enableWaffoTopUp ? ( ) : enableOnlineTopUp ||
enableStripeTopUp ||
enableCreemTopUp ||
enableWaffoTopUp ||
enableWaffoPancakeTopUp ? (
<Form <Form
getFormApi={(api) => (onlineFormApiRef.current = api)} getFormApi={(api) => (onlineFormApiRef.current = api)}
initValues={{ topUpCount: topUpCount }} initValues={{ topUpCount: topUpCount }}
> >
<div className='space-y-6'> <div className='space-y-6'>
{(enableOnlineTopUp || enableStripeTopUp || enableWaffoTopUp) && ( {(enableOnlineTopUp ||
enableStripeTopUp ||
enableWaffoTopUp ||
enableWaffoPancakeTopUp) && (
<Row gutter={12}> <Row gutter={12}>
<Col xs={24} sm={24} md={24} lg={10} xl={10}> <Col xs={24} sm={24} md={24} lg={10} xl={10}>
<Form.InputNumber <Form.InputNumber
field='topUpCount' field='topUpCount'
label={t('充值数量')} label={t('充值数量')}
disabled={!enableOnlineTopUp && !enableStripeTopUp && !enableWaffoTopUp} disabled={
!enableOnlineTopUp &&
!enableStripeTopUp &&
!enableWaffoTopUp &&
!enableWaffoPancakeTopUp
}
placeholder={ placeholder={
t('充值数量,最低 ') + renderQuotaWithAmount(minTopUp) t('充值数量,最低 ') + renderQuotaWithAmount(minTopUp)
} }
@@ -291,16 +303,27 @@ const RechargeCard = ({
style={{ width: '100%' }} style={{ width: '100%' }}
/> />
</Col> </Col>
{payMethods && payMethods.filter(m => m.type !== 'waffo').length > 0 && ( {regularPayMethods.length > 0 && (
<Col xs={24} sm={24} md={24} lg={14} xl={14}> <Col xs={24} sm={24} md={24} lg={14} xl={14}>
<Form.Slot label={t('选择支付方式')}> <Form.Slot label={t('选择支付方式')}>
<Space wrap> <Space wrap>
{payMethods.filter(m => m.type !== 'waffo').map((payMethod) => { {regularPayMethods.map((payMethod) => {
const minTopupVal = Number(payMethod.min_topup) || 0; const minTopupVal =
Number(payMethod.min_topup) || 0;
const isStripe = payMethod.type === 'stripe'; const isStripe = payMethod.type === 'stripe';
const isWaffo =
typeof payMethod.type === 'string' &&
payMethod.type.startsWith('waffo:');
const isWaffoPancake =
payMethod.type === 'waffo_pancake';
const disabled = const disabled =
(!enableOnlineTopUp && !isStripe) || (!enableOnlineTopUp &&
!isStripe &&
!isWaffo &&
!isWaffoPancake) ||
(!enableStripeTopUp && isStripe) || (!enableStripeTopUp && isStripe) ||
(!enableWaffoTopUp && isWaffo) ||
(!enableWaffoPancakeTopUp && isWaffoPancake) ||
minTopupVal > Number(topUpCount || 0); minTopupVal > Number(topUpCount || 0);
const buttonEl = ( const buttonEl = (
@@ -320,6 +343,21 @@ const RechargeCard = ({
<SiWechat size={18} color='#07C160' /> <SiWechat size={18} color='#07C160' />
) : payMethod.type === 'stripe' ? ( ) : payMethod.type === 'stripe' ? (
<SiStripe size={18} color='#635BFF' /> <SiStripe size={18} color='#635BFF' />
) : payMethod.icon ? (
<img
src={payMethod.icon}
alt={payMethod.name}
style={{
width: 18,
height: 18,
objectFit: 'contain',
}}
/>
) : payMethod.type === 'waffo_pancake' ? (
<CreditCard
size={18}
color='var(--semi-color-primary)'
/>
) : ( ) : (
<CreditCard <CreditCard
size={18} size={18}
@@ -355,8 +393,8 @@ const RechargeCard = ({
); );
})} })}
</Space> </Space>
</Form.Slot> </Form.Slot>
</Col> </Col>
)} )}
</Row> </Row>
)} )}
@@ -388,7 +426,9 @@ const RechargeCard = ({
<div className='grid grid-cols-2 sm:grid-cols-3 md:grid-cols-4 gap-2'> <div className='grid grid-cols-2 sm:grid-cols-3 md:grid-cols-4 gap-2'>
{presetAmounts.map((preset, index) => { {presetAmounts.map((preset, index) => {
const discount = const discount =
preset.discount || topupInfo?.discount?.[preset.value] || 1.0; preset.discount ||
topupInfo?.discount?.[preset.value] ||
1.0;
const originalPrice = preset.value * priceRatio; const originalPrice = preset.value * priceRatio;
const discountedPrice = originalPrice * discount; const discountedPrice = originalPrice * discount;
const hasDiscount = discount < 1.0; const hasDiscount = discount < 1.0;
@@ -404,7 +444,7 @@ const RechargeCard = ({
const s = JSON.parse(statusStr); const s = JSON.parse(statusStr);
usdRate = s?.usd_exchange_rate || 7; usdRate = s?.usd_exchange_rate || 7;
} }
} catch (e) { } } catch (e) {}
let displayValue = preset.value; // let displayValue = preset.value; //
let displayActualPay = actualPay; let displayActualPay = actualPay;
@@ -455,7 +495,10 @@ const RechargeCard = ({
{hasDiscount && ( {hasDiscount && (
<Tag style={{ marginLeft: 4 }} color='green'> <Tag style={{ marginLeft: 4 }} color='green'>
{t('折').includes('off') {t('折').includes('off')
? ((1 - parseFloat(discount)) * 100).toFixed(1) ? (
(1 - parseFloat(discount)) *
100
).toFixed(1)
: (discount * 10).toFixed(1)} : (discount * 10).toFixed(1)}
{t('折')} {t('折')}
</Tag> </Tag>
@@ -482,46 +525,6 @@ const RechargeCard = ({
</Form.Slot> </Form.Slot>
)} )}
{/* Waffo 充值区域 */}
{enableWaffoTopUp &&
waffoPayMethods &&
waffoPayMethods.length > 0 && (
<Form.Slot label={t('Waffo 充值')}>
<Space wrap>
{waffoPayMethods.map((method, index) => (
<Button
key={index}
theme='outline'
type='tertiary'
onClick={() => waffoTopUp(index)}
loading={paymentLoading}
icon={
method.icon ? (
<img
src={method.icon}
alt={method.name}
style={{
width: 36,
height: 36,
objectFit: 'contain',
}}
/>
) : (
<CreditCard
size={18}
color='var(--semi-color-text-2)'
/>
)
}
className='!rounded-lg !px-4 !py-2'
>
{method.name}
</Button>
))}
</Space>
</Form.Slot>
)}
{/* Creem 充值区域 */} {/* Creem 充值区域 */}
{enableCreemTopUp && creemProducts.length > 0 && ( {enableCreemTopUp && creemProducts.length > 0 && (
<Form.Slot label={t('Creem 充值')}> <Form.Slot label={t('Creem 充值')}>
@@ -442,6 +442,14 @@ const SubscriptionPlansCard = ({
(subscription?.end_time || 0) * 1000, (subscription?.end_time || 0) * 1000,
).toLocaleString()} ).toLocaleString()}
</div> </div>
{isActive && subscription?.next_reset_time > 0 && (
<div className='text-xs text-gray-500 mb-2'>
{t('下一次重置')}:{' '}
{new Date(
subscription.next_reset_time * 1000,
).toLocaleString()}
</div>
)}
<div className='text-xs text-gray-500 mb-2'> <div className='text-xs text-gray-500 mb-2'>
{t('总额度')}:{' '} {t('总额度')}:{' '}
{totalAmount > 0 ? ( {totalAmount > 0 ? (
+195 -35
View File
@@ -75,6 +75,8 @@ const TopUp = () => {
const [enableWaffoTopUp, setEnableWaffoTopUp] = useState(false); const [enableWaffoTopUp, setEnableWaffoTopUp] = useState(false);
const [waffoPayMethods, setWaffoPayMethods] = useState([]); const [waffoPayMethods, setWaffoPayMethods] = useState([]);
const [waffoMinTopUp, setWaffoMinTopUp] = useState(1); const [waffoMinTopUp, setWaffoMinTopUp] = useState(1);
const [enableWaffoPancakeTopUp, setEnableWaffoPancakeTopUp] = useState(false);
const [waffoPancakeMinTopUp, setWaffoPancakeMinTopUp] = useState(1);
const [isSubmitting, setIsSubmitting] = useState(false); const [isSubmitting, setIsSubmitting] = useState(false);
const [open, setOpen] = useState(false); const [open, setOpen] = useState(false);
@@ -112,6 +114,39 @@ const TopUp = () => {
discount: {}, discount: {},
}); });
const confirmPayMethods = [
...payMethods,
...waffoPayMethods.map((method, index) => ({
...method,
type: `waffo:${index}`,
min_topup: waffoMinTopUp,
color: method.color || 'rgba(var(--semi-primary-5), 1)',
})),
];
const getPayMethodConfig = (payment) =>
confirmPayMethods.find((method) => method.type === payment);
const getPaymentMinTopUp = (payment) => {
const configuredMinTopUp = Number(getPayMethodConfig(payment)?.min_topup);
return Number.isFinite(configuredMinTopUp) && configuredMinTopUp > 0
? configuredMinTopUp
: minTopUp;
};
const requestAmountByPayment = async (payment, value) => {
if (payment === 'stripe') {
return getStripeAmount(value);
}
if (payment === 'waffo_pancake') {
return getWaffoPancakeAmount(value);
}
if (typeof payment === 'string' && payment.startsWith('waffo:')) {
return getWaffoAmount(value);
}
return getAmount(value);
};
const topUp = async () => { const topUp = async () => {
if (redemptionCode === '') { if (redemptionCode === '') {
showInfo(t('请输入兑换码!')); showInfo(t('请输入兑换码!'));
@@ -162,6 +197,16 @@ const TopUp = () => {
showError(t('管理员未开启Stripe充值!')); showError(t('管理员未开启Stripe充值!'));
return; return;
} }
} else if (payment === 'waffo_pancake') {
if (!enableWaffoPancakeTopUp) {
showError(t('管理员未开启 Waffo Pancake 充值!'));
return;
}
} else if (payment.startsWith('waffo:')) {
if (!enableWaffoTopUp) {
showError(t('管理员未开启 Waffo 充值!'));
return;
}
} else { } else {
if (!enableOnlineTopUp) { if (!enableOnlineTopUp) {
showError(t('管理员未开启在线充值!')); showError(t('管理员未开启在线充值!'));
@@ -172,14 +217,11 @@ const TopUp = () => {
setPayWay(payment); setPayWay(payment);
setPaymentLoading(true); setPaymentLoading(true);
try { try {
if (payment === 'stripe') { const selectedMinTopUp = getPaymentMinTopUp(payment);
await getStripeAmount(); await requestAmountByPayment(payment);
} else {
await getAmount();
}
if (topUpCount < minTopUp) { if (topUpCount < selectedMinTopUp) {
showError(t('充值数量不能小于') + minTopUp); showError(t('充值数量不能小于') + selectedMinTopUp);
return; return;
} }
setOpen(true); setOpen(true);
@@ -191,6 +233,29 @@ const TopUp = () => {
}; };
const onlineTopUp = async () => { const onlineTopUp = async () => {
if (payWay === 'waffo_pancake') {
setConfirmLoading(true);
try {
await waffoPancakeTopUp();
} finally {
setOpen(false);
setConfirmLoading(false);
}
return;
}
if (payWay.startsWith('waffo:')) {
const payMethodIndex = Number(payWay.split(':')[1]);
setConfirmLoading(true);
try {
await waffoTopUp(Number.isFinite(payMethodIndex) ? payMethodIndex : 0);
} finally {
setOpen(false);
setConfirmLoading(false);
}
return;
}
if (payWay === 'stripe') { if (payWay === 'stripe') {
// Stripe // Stripe
if (amount === 0) { if (amount === 0) {
@@ -317,32 +382,122 @@ const TopUp = () => {
const waffoTopUp = async (payMethodIndex) => { const waffoTopUp = async (payMethodIndex) => {
try { try {
if (topUpCount < waffoMinTopUp) { if (topUpCount < waffoMinTopUp) {
showError(t('充值数量不能小于') + waffoMinTopUp); showError(t('充值数量不能小于') + waffoMinTopUp);
return; return;
} }
setPaymentLoading(true); setPaymentLoading(true);
const requestBody = { const requestBody = {
amount: parseInt(topUpCount), amount: parseInt(topUpCount),
}; };
if (payMethodIndex != null) { if (payMethodIndex != null) {
requestBody.pay_method_index = payMethodIndex; requestBody.pay_method_index = payMethodIndex;
} }
const res = await API.post('/api/user/waffo/pay', requestBody); const res = await API.post('/api/user/waffo/pay', requestBody);
if (res !== undefined) { if (res !== undefined) {
const { message, data } = res.data; const { message, data } = res.data;
if (message === 'success' && data?.payment_url) { if (message === 'success' && data?.payment_url) {
window.open(data.payment_url, '_blank'); window.open(data.payment_url, '_blank');
} else {
showError(data || t('支付请求失败'));
}
} else { } else {
showError(res); showError(data || t('支付请求失败'));
} }
} else {
showError(res);
}
} catch (e) { } catch (e) {
showError(t('支付请求失败')); showError(t('支付请求失败'));
} finally { } finally {
setPaymentLoading(false); setPaymentLoading(false);
}
};
const getWaffoAmount = async (value) => {
if (value === undefined) {
value = topUpCount;
}
setAmountLoading(true);
try {
const res = await API.post('/api/user/waffo/amount', {
amount: parseInt(value),
});
if (res !== undefined) {
const { message, data } = res.data;
if (message === 'success') {
setAmount(parseFloat(data));
} else {
setAmount(0);
Toast.error({ content: '错误:' + data, id: 'getAmount' });
}
} else {
showError(res);
}
} catch (err) {
// amount fetch failed silently
} finally {
setAmountLoading(false);
}
};
const waffoPancakeTopUp = async () => {
const minTopUpValue = Number(waffoPancakeMinTopUp || 1);
if (topUpCount < minTopUpValue) {
showError(t('充值数量不能小于') + minTopUpValue);
return;
}
setPaymentLoading(true);
try {
const res = await API.post('/api/user/waffo-pancake/pay', {
amount: parseInt(topUpCount),
});
if (res !== undefined) {
const { message, data } = res.data;
if (message === 'success') {
const checkoutUrl = data?.checkout_url || '';
if (checkoutUrl) {
window.open(checkoutUrl, '_blank');
} else {
showError(t('支付请求失败'));
}
} else {
const errorMsg =
typeof data === 'string' ? data : message || t('支付请求失败');
showError(errorMsg);
}
} else {
showError(res);
}
} catch (e) {
showError(t('支付请求失败'));
} finally {
setPaymentLoading(false);
}
};
const getWaffoPancakeAmount = async (value) => {
if (value === undefined) {
value = topUpCount;
}
setAmountLoading(true);
try {
const res = await API.post('/api/user/waffo-pancake/amount', {
amount: parseInt(value),
});
if (res !== undefined) {
const { message, data } = res.data;
if (message === 'success') {
setAmount(parseFloat(data));
} else {
setAmount(0);
Toast.error({ content: '错误:' + data, id: 'getAmount' });
}
} else {
showError(res);
}
} catch (err) {
// amount fetch failed silently
} finally {
setAmountLoading(false);
} }
}; };
@@ -481,20 +636,26 @@ const TopUp = () => {
const enableStripeTopUp = data.enable_stripe_topup || false; const enableStripeTopUp = data.enable_stripe_topup || false;
const enableOnlineTopUp = data.enable_online_topup || false; const enableOnlineTopUp = data.enable_online_topup || false;
const enableCreemTopUp = data.enable_creem_topup || false; const enableCreemTopUp = data.enable_creem_topup || false;
const enableWaffoTopUp = data.enable_waffo_topup || false;
const enableWaffoPancakeTopUp =
data.enable_waffo_pancake_topup || false;
const minTopUpValue = enableOnlineTopUp const minTopUpValue = enableOnlineTopUp
? data.min_topup ? data.min_topup
: enableStripeTopUp : enableStripeTopUp
? data.stripe_min_topup ? data.stripe_min_topup
: data.enable_waffo_topup : enableWaffoTopUp
? data.waffo_min_topup ? data.waffo_min_topup
: enableWaffoPancakeTopUp
? data.waffo_pancake_min_topup
: 1; : 1;
setEnableOnlineTopUp(enableOnlineTopUp); setEnableOnlineTopUp(enableOnlineTopUp);
setEnableStripeTopUp(enableStripeTopUp); setEnableStripeTopUp(enableStripeTopUp);
setEnableCreemTopUp(enableCreemTopUp); setEnableCreemTopUp(enableCreemTopUp);
const enableWaffoTopUp = data.enable_waffo_topup || false;
setEnableWaffoTopUp(enableWaffoTopUp); setEnableWaffoTopUp(enableWaffoTopUp);
setWaffoPayMethods(data.waffo_pay_methods || []); setWaffoPayMethods(data.waffo_pay_methods || []);
setWaffoMinTopUp(data.waffo_min_topup || 1); setWaffoMinTopUp(data.waffo_min_topup || 1);
setEnableWaffoPancakeTopUp(enableWaffoPancakeTopUp);
setWaffoPancakeMinTopUp(data.waffo_pancake_min_topup || 1);
setMinTopUp(minTopUpValue); setMinTopUp(minTopUpValue);
setTopUpCount(minTopUpValue); setTopUpCount(minTopUpValue);
@@ -739,7 +900,7 @@ const TopUp = () => {
amountLoading={amountLoading} amountLoading={amountLoading}
renderAmount={renderAmount} renderAmount={renderAmount}
payWay={payWay} payWay={payWay}
payMethods={payMethods} payMethods={confirmPayMethods}
amountNumber={amount} amountNumber={amount}
discountRate={topupInfo?.discount?.[topUpCount] || 1.0} discountRate={topupInfo?.discount?.[topUpCount] || 1.0}
/> />
@@ -789,8 +950,7 @@ const TopUp = () => {
creemProducts={creemProducts} creemProducts={creemProducts}
creemPreTopUp={creemPreTopUp} creemPreTopUp={creemPreTopUp}
enableWaffoTopUp={enableWaffoTopUp} enableWaffoTopUp={enableWaffoTopUp}
waffoTopUp={waffoTopUp} enableWaffoPancakeTopUp={enableWaffoPancakeTopUp}
waffoPayMethods={waffoPayMethods}
presetAmounts={presetAmounts} presetAmounts={presetAmounts}
selectedPreset={selectedPreset} selectedPreset={selectedPreset}
selectPresetAmount={selectPresetAmount} selectPresetAmount={selectPresetAmount}
@@ -804,7 +964,7 @@ const TopUp = () => {
setSelectedPreset={setSelectedPreset} setSelectedPreset={setSelectedPreset}
renderAmount={renderAmount} renderAmount={renderAmount}
amountLoading={amountLoading} amountLoading={amountLoading}
payMethods={payMethods} payMethods={confirmPayMethods}
preTopUp={preTopUp} preTopUp={preTopUp}
paymentLoading={paymentLoading} paymentLoading={paymentLoading}
payWay={payWay} payWay={payWay}
@@ -140,6 +140,17 @@ const PaymentConfirmModal = ({
size={16} size={16}
color='#635BFF' color='#635BFF'
/> />
) : payMethod.icon ? (
<img
src={payMethod.icon}
alt={payMethod.name}
className='mr-2'
style={{
width: 16,
height: 16,
objectFit: 'contain',
}}
/>
) : ( ) : (
<CreditCard <CreditCard
className='mr-2' className='mr-2'
@@ -161,6 +161,16 @@ const TopupHistoryModal = ({ visible, onCancel, t }) => {
const columns = useMemo(() => { const columns = useMemo(() => {
const baseColumns = [ const baseColumns = [
...(userIsAdmin
? [
{
title: t('用户ID'),
dataIndex: 'user_id',
key: 'user_id',
render: (userId) => <Text>{userId ?? '-'}</Text>,
},
]
: []),
{ {
title: t('订单号'), title: t('订单号'),
dataIndex: 'trade_no', dataIndex: 'trade_no',
+29 -7
View File
@@ -1,3 +1,21 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import { getCurrencyConfig } from './render'; import { getCurrencyConfig } from './render';
export const getQuotaPerUnit = () => { export const getQuotaPerUnit = () => {
@@ -7,19 +25,23 @@ export const getQuotaPerUnit = () => {
export const quotaToDisplayAmount = (quota) => { export const quotaToDisplayAmount = (quota) => {
const q = Number(quota || 0); const q = Number(quota || 0);
if (!Number.isFinite(q) || q <= 0) return 0; if (!Number.isFinite(q) || q === 0) return 0;
const sign = Math.sign(q);
const abs = Math.abs(q);
const { type, rate } = getCurrencyConfig(); const { type, rate } = getCurrencyConfig();
if (type === 'TOKENS') return q; if (type === 'TOKENS') return q;
const usd = q / getQuotaPerUnit(); const usd = abs / getQuotaPerUnit();
if (type === 'USD') return usd; if (type === 'USD') return sign * usd;
return usd * (rate || 1); return sign * usd * (rate || 1);
}; };
export const displayAmountToQuota = (amount) => { export const displayAmountToQuota = (amount) => {
const val = Number(amount || 0); const val = Number(amount || 0);
if (!Number.isFinite(val) || val <= 0) return 0; if (!Number.isFinite(val) || val === 0) return 0;
const sign = Math.sign(val);
const abs = Math.abs(val);
const { type, rate } = getCurrencyConfig(); const { type, rate } = getCurrencyConfig();
if (type === 'TOKENS') return Math.round(val); if (type === 'TOKENS') return Math.round(val);
const usd = type === 'USD' ? val : val / (rate || 1); const usd = type === 'USD' ? abs : abs / (rate || 1);
return Math.round(usd * getQuotaPerUnit()); return sign * Math.round(usd * getQuotaPerUnit());
}; };
+5 -3
View File
@@ -890,7 +890,7 @@ export const useChannelsData = () => {
return Promise.resolve(); return Promise.resolve();
} }
const { success, message, time } = res.data; const { success, message, time, error_code } = res.data;
// //
setModelTestResults((prev) => ({ setModelTestResults((prev) => ({
@@ -900,6 +900,7 @@ export const useChannelsData = () => {
message, message,
time: time || 0, time: time || 0,
timestamp: Date.now(), timestamp: Date.now(),
errorCode: error_code || null,
}, },
})); }));
@@ -927,7 +928,7 @@ export const useChannelsData = () => {
); );
} }
} else { } else {
showError(`${t('模型')} ${model}: ${message}`); showError(message);
} }
} catch (error) { } catch (error) {
// //
@@ -939,9 +940,10 @@ export const useChannelsData = () => {
message: error.message || t('网络错误'), message: error.message || t('网络错误'),
time: 0, time: 0,
timestamp: Date.now(), timestamp: Date.now(),
errorCode: null,
}, },
})); }));
showError(`${t('模型')} ${model}: ${error.message || t('测试失败')}`); showError(error.message || t('测试失败'));
} finally { } finally {
// //
setTestingModels((prev) => { setTestingModels((prev) => {

Some files were not shown because too many files have changed in this diff Show More