Merge remote-tracking branch 'origin/main' into codex/redeem-subscription

# Conflicts:
#	web/classic/src/components/topup/index.jsx
This commit is contained in:
Lich-Mac-Mini
2026-04-28 16:06:03 +08:00
1394 changed files with 176148 additions and 3930 deletions
+94 -14
View File
@@ -20,6 +20,7 @@ import (
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/middleware"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/pkg/billingexpr"
"github.com/QuantumNous/new-api/relay"
relaycommon "github.com/QuantumNous/new-api/relay/common"
relayconstant "github.com/QuantumNous/new-api/relay/constant"
@@ -233,6 +234,15 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
info.IsChannelTest = true
info.InitChannelMeta(c)
err = attachTestBillingRequestInput(info, request)
if err != nil {
return testResult{
context: c,
localErr: err,
newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
}
}
err = helper.ModelMappedHelper(c, info, request)
if err != nil {
return testResult{
@@ -460,7 +470,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
}
}
if bodyErr := detectErrorFromTestResponseBody(respBody); bodyErr != nil {
if bodyErr := validateTestResponseBody(respBody, isStream); bodyErr != nil {
return testResult{
context: c,
localErr: bodyErr,
@@ -469,21 +479,11 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
}
info.SetEstimatePromptTokens(usage.PromptTokens)
quota := 0
if !priceData.UsePrice {
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
quota = int(math.Round(float64(quota) * priceData.ModelRatio))
if priceData.ModelRatio != 0 && quota <= 0 {
quota = 1
}
} else {
quota = int(priceData.ModelPrice * common.QuotaPerUnit)
}
quota, tieredResult := settleTestQuota(info, priceData, usage)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
consumedTime := float64(milliseconds) / 1000.0
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
other := buildTestLogOther(c, info, priceData, usage, tieredResult)
model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
ChannelId: channel.Id,
PromptTokens: usage.PromptTokens,
@@ -505,6 +505,50 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
}
}
func attachTestBillingRequestInput(info *relaycommon.RelayInfo, request dto.Request) error {
if info == nil {
return nil
}
input, err := helper.BuildBillingExprRequestInputFromRequest(request, info.RequestHeaders)
if err != nil {
return err
}
info.BillingRequestInput = &input
return nil
}
func settleTestQuota(info *relaycommon.RelayInfo, priceData types.PriceData, usage *dto.Usage) (int, *billingexpr.TieredResult) {
if usage != nil && info != nil && info.TieredBillingSnapshot != nil {
isClaudeUsageSemantic := usage.UsageSemantic == "anthropic" || info.GetFinalRequestRelayFormat() == types.RelayFormatClaude
usedVars := billingexpr.UsedVars(info.TieredBillingSnapshot.ExprString)
if ok, quota, result := service.TryTieredSettle(info, service.BuildTieredTokenParams(usage, isClaudeUsageSemantic, usedVars)); ok {
return quota, result
}
}
quota := 0
if !priceData.UsePrice {
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
quota = int(math.Round(float64(quota) * priceData.ModelRatio))
if priceData.ModelRatio != 0 && quota <= 0 {
quota = 1
}
return quota, nil
}
return int(priceData.ModelPrice * common.QuotaPerUnit), nil
}
func buildTestLogOther(c *gin.Context, info *relaycommon.RelayInfo, priceData types.PriceData, usage *dto.Usage, tieredResult *billingexpr.TieredResult) map[string]interface{} {
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
if tieredResult != nil {
service.InjectTieredBillingInfo(other, info, tieredResult)
}
return other
}
func coerceTestUsage(usageAny any, isStream bool, estimatePromptTokens int) (*dto.Usage, error) {
switch u := usageAny.(type) {
case *dto.Usage:
@@ -570,6 +614,42 @@ func detectErrorFromTestResponseBody(respBody []byte) error {
return nil
}
func validateStreamTestResponseBody(respBody []byte) error {
b := bytes.TrimSpace(respBody)
if len(b) == 0 {
return errors.New("stream response body is empty")
}
for _, line := range bytes.Split(b, []byte{'\n'}) {
line = bytes.TrimSpace(line)
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data:")) {
continue
}
payload := bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data:")))
if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) {
continue
}
return nil
}
return errors.New("stream response body does not contain a valid stream event")
}
func validateTestResponseBody(respBody []byte, isStream bool) error {
if bodyErr := detectErrorFromTestResponseBody(respBody); bodyErr != nil {
return bodyErr
}
if isStream {
return validateStreamTestResponseBody(respBody)
}
return nil
}
func shouldUseStreamForAutomaticChannelTest(channel *model.Channel) bool {
return channel != nil && channel.Type == constant.ChannelTypeCodex
}
func detectErrorMessageFromJSONBytes(jsonBytes []byte) string {
if len(jsonBytes) == 0 {
return ""
@@ -822,7 +902,7 @@ func testAllChannels(notify bool) error {
}
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now()
result := testChannel(channel, "", "", false)
result := testChannel(channel, "", "", shouldUseStreamForAutomaticChannelTest(channel))
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
+71
View File
@@ -0,0 +1,71 @@
package controller
import (
"net/http/httptest"
"testing"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/pkg/billingexpr"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestSettleTestQuotaUsesTieredBilling(t *testing.T) {
info := &relaycommon.RelayInfo{
TieredBillingSnapshot: &billingexpr.BillingSnapshot{
BillingMode: "tiered_expr",
ExprString: `param("stream") == true ? tier("stream", p * 3) : tier("base", p * 2)`,
ExprHash: billingexpr.ExprHashString(`param("stream") == true ? tier("stream", p * 3) : tier("base", p * 2)`),
GroupRatio: 1,
EstimatedTier: "stream",
QuotaPerUnit: common.QuotaPerUnit,
ExprVersion: 1,
},
BillingRequestInput: &billingexpr.RequestInput{
Body: []byte(`{"stream":true}`),
},
}
quota, result := settleTestQuota(info, types.PriceData{
ModelRatio: 1,
CompletionRatio: 2,
}, &dto.Usage{
PromptTokens: 1000,
})
require.Equal(t, 1500, quota)
require.NotNil(t, result)
require.Equal(t, "stream", result.MatchedTier)
}
func TestBuildTestLogOtherInjectsTieredInfo(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
info := &relaycommon.RelayInfo{
TieredBillingSnapshot: &billingexpr.BillingSnapshot{
BillingMode: "tiered_expr",
ExprString: `tier("base", p * 2)`,
},
ChannelMeta: &relaycommon.ChannelMeta{},
}
priceData := types.PriceData{
GroupRatioInfo: types.GroupRatioInfo{GroupRatio: 1},
}
usage := &dto.Usage{
PromptTokensDetails: dto.InputTokenDetails{
CachedTokens: 12,
},
}
other := buildTestLogOther(ctx, info, priceData, usage, &billingexpr.TieredResult{
MatchedTier: "base",
})
require.Equal(t, "tiered_expr", other["billing_mode"])
require.Equal(t, "base", other["matched_tier"])
require.NotEmpty(t, other["expr_b64"])
}
+22 -2
View File
@@ -32,6 +32,26 @@ const (
channelUpstreamModelUpdateNotifyMaxFailedChannelIDs = 10
)
var channelUpstreamModelUpdateSelectFields = []string{
"id",
"name",
"type",
"key",
"status",
"base_url",
"models",
"model_mapping",
"settings",
"setting",
"other",
"group",
"priority",
"weight",
"tag",
"channel_info",
"header_override",
}
var (
channelUpstreamModelUpdateTaskOnce sync.Once
channelUpstreamModelUpdateTaskRunning atomic.Bool
@@ -521,7 +541,7 @@ func runChannelUpstreamModelUpdateTaskOnce() {
for {
var channels []*model.Channel
query := model.DB.
Select("id", "name", "type", "key", "status", "base_url", "models", "settings", "setting", "other", "group", "priority", "weight", "tag", "channel_info", "header_override").
Select(channelUpstreamModelUpdateSelectFields).
Where("status = ?", common.ChannelStatusEnabled).
Order("id asc").
Limit(channelUpstreamModelUpdateTaskBatchSize)
@@ -814,7 +834,7 @@ func collectPendingApplyUpstreamModelChanges(settings dto.ChannelOtherSettings)
func findEnabledChannelsAfterID(lastID int, batchSize int) ([]*model.Channel, error) {
var channels []*model.Channel
query := model.DB.
Select("id", "name", "type", "key", "status", "base_url", "models", "settings", "setting", "other", "group", "priority", "weight", "tag", "channel_info", "header_override").
Select(channelUpstreamModelUpdateSelectFields).
Where("status = ?", common.ChannelStatusEnabled).
Order("id asc").
Limit(batchSize)
@@ -81,6 +81,10 @@ func TestCollectPendingApplyUpstreamModelChanges(t *testing.T) {
require.Equal(t, []string{"old-model"}, pendingRemoveModels)
}
func TestChannelUpstreamModelUpdateSelectFieldsIncludeModelMapping(t *testing.T) {
require.Contains(t, channelUpstreamModelUpdateSelectFields, "model_mapping")
}
func TestNormalizeChannelModelMapping(t *testing.T) {
modelMapping := `{
" alias-model ": " upstream-model ",
+223
View File
@@ -0,0 +1,223 @@
package controller
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
type DiscordResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
}
type DiscordUser struct {
UID string `json:"id"`
ID string `json:"username"`
Name string `json:"global_name"`
}
func getDiscordUserInfoByCode(code string) (*DiscordUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
values := url.Values{}
values.Set("client_id", system_setting.GetDiscordSettings().ClientId)
values.Set("client_secret", system_setting.GetDiscordSettings().ClientSecret)
values.Set("code", code)
values.Set("grant_type", "authorization_code")
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/discord", system_setting.ServerAddress))
formData := values.Encode()
req, err := http.NewRequest("POST", "https://discord.com/api/v10/oauth2/token", strings.NewReader(formData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!")
}
defer res.Body.Close()
var discordResponse DiscordResponse
err = json.NewDecoder(res.Body).Decode(&discordResponse)
if err != nil {
return nil, err
}
if discordResponse.AccessToken == "" {
common.SysError("Discord 获取 Token 失败,请检查设置!")
return nil, errors.New("Discord 获取 Token 失败,请检查设置!")
}
req, err = http.NewRequest("GET", "https://discord.com/api/v10/users/@me", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+discordResponse.AccessToken)
res2, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!")
}
defer res2.Body.Close()
if res2.StatusCode != http.StatusOK {
common.SysError("Discord 获取用户信息失败!请检查设置!")
return nil, errors.New("Discord 获取用户信息失败!请检查设置!")
}
var discordUser DiscordUser
err = json.NewDecoder(res2.Body).Decode(&discordUser)
if err != nil {
return nil, err
}
if discordUser.UID == "" || discordUser.ID == "" {
common.SysError("Discord 获取用户信息为空!请检查设置!")
return nil, errors.New("Discord 获取用户信息为空!请检查设置!")
}
return &discordUser, nil
}
func DiscordOAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
DiscordBind(c)
return
}
if !system_setting.GetDiscordSettings().Enabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 Discord 登录以及注册",
})
return
}
code := c.Query("code")
discordUser, err := getDiscordUserInfoByCode(code)
if err != nil {
common.ApiError(c, err)
return
}
user := model.User{
DiscordId: discordUser.UID,
}
if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
err := user.FillUserByDiscordId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if common.RegisterEnabled {
if discordUser.ID != "" {
user.Username = discordUser.ID
} else {
user.Username = "discord_" + strconv.Itoa(model.GetMaxUserId()+1)
}
if discordUser.Name != "" {
user.DisplayName = discordUser.Name
} else {
user.DisplayName = "Discord User"
}
err := user.Insert(0)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}
func DiscordBind(c *gin.Context) {
if !system_setting.GetDiscordSettings().Enabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 Discord 登录以及注册",
})
return
}
code := c.Query("code")
discordUser, err := getDiscordUserInfoByCode(code)
if err != nil {
common.ApiError(c, err)
return
}
user := model.User{
DiscordId: discordUser.UID,
}
if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 Discord 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
common.ApiError(c, err)
return
}
user.DiscordId = discordUser.UID
err = user.Update(false)
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
}
+220
View File
@@ -0,0 +1,220 @@
package controller
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
type GitHubOAuthResponse struct {
AccessToken string `json:"access_token"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
}
type GitHubUser struct {
Login string `json:"login"`
Name string `json:"name"`
Email string `json:"email"`
}
func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code}
jsonData, err := json.Marshal(values)
if err != nil {
return nil, err
}
req, err := http.NewRequest("POST", "https://github.com/login/oauth/access_token", bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 20 * time.Second,
}
res, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
}
defer res.Body.Close()
var oAuthResponse GitHubOAuthResponse
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
if err != nil {
return nil, err
}
req, err = http.NewRequest("GET", "https://api.github.com/user", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
res2, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
}
defer res2.Body.Close()
var githubUser GitHubUser
err = json.NewDecoder(res2.Body).Decode(&githubUser)
if err != nil {
return nil, err
}
if githubUser.Login == "" {
return nil, errors.New("返回值非法,用户字段为空,请稍后重试!")
}
return &githubUser, nil
}
func GitHubOAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
GitHubBind(c)
return
}
if !common.GitHubOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 GitHub 登录以及注册",
})
return
}
code := c.Query("code")
githubUser, err := getGitHubUserInfoByCode(code)
if err != nil {
common.ApiError(c, err)
return
}
user := model.User{
GitHubId: githubUser.Login,
}
// IsGitHubIdAlreadyTaken is unscoped
if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
// FillUserByGitHubId is scoped
err := user.FillUserByGitHubId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
// if user.Id == 0 , user has been deleted
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已注销",
})
return
}
} else {
if common.RegisterEnabled {
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
if githubUser.Name != "" {
user.DisplayName = githubUser.Name
} else {
user.DisplayName = "GitHub User"
}
user.Email = githubUser.Email
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
affCode := session.Get("aff")
inviterId := 0
if affCode != nil {
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
}
if err := user.Insert(inviterId); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}
func GitHubBind(c *gin.Context) {
if !common.GitHubOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 GitHub 登录以及注册",
})
return
}
code := c.Query("code")
githubUser, err := getGitHubUserInfoByCode(code)
if err != nil {
common.ApiError(c, err)
return
}
user := model.User{
GitHubId: githubUser.Login,
}
if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 GitHub 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
common.ApiError(c, err)
return
}
user.GitHubId = githubUser.Login
err = user.Update(false)
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}
+268
View File
@@ -0,0 +1,268 @@
package controller
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
type LinuxdoUser struct {
Id int `json:"id"`
Username string `json:"username"`
Name string `json:"name"`
Active bool `json:"active"`
TrustLevel int `json:"trust_level"`
Silenced bool `json:"silenced"`
}
func LinuxDoBind(c *gin.Context) {
if !common.LinuxDOOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 Linux DO 登录以及注册",
})
return
}
code := c.Query("code")
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
if err != nil {
common.ApiError(c, err)
return
}
user := model.User{
LinuxDOId: strconv.Itoa(linuxdoUser.Id),
}
if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 Linux DO 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
common.ApiError(c, err)
return
}
user.LinuxDOId = strconv.Itoa(linuxdoUser.Id)
err = user.Update(false)
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
}
func getLinuxdoUserInfoByCode(code string, c *gin.Context) (*LinuxdoUser, error) {
if code == "" {
return nil, errors.New("invalid code")
}
// Get access token using Basic auth
tokenEndpoint := common.GetEnvOrDefaultString("LINUX_DO_TOKEN_ENDPOINT", "https://connect.linux.do/oauth2/token")
credentials := common.LinuxDOClientId + ":" + common.LinuxDOClientSecret
basicAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))
// Get redirect URI from request
scheme := "http"
if c.Request.TLS != nil {
scheme = "https"
}
redirectURI := fmt.Sprintf("%s://%s/api/oauth/linuxdo", scheme, c.Request.Host)
data := url.Values{}
data.Set("grant_type", "authorization_code")
data.Set("code", code)
data.Set("redirect_uri", redirectURI)
req, err := http.NewRequest("POST", tokenEndpoint, strings.NewReader(data.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Authorization", basicAuth)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
client := http.Client{Timeout: 5 * time.Second}
res, err := client.Do(req)
if err != nil {
return nil, errors.New("failed to connect to Linux DO server")
}
defer res.Body.Close()
var tokenRes struct {
AccessToken string `json:"access_token"`
Message string `json:"message"`
}
if err := json.NewDecoder(res.Body).Decode(&tokenRes); err != nil {
return nil, err
}
if tokenRes.AccessToken == "" {
return nil, fmt.Errorf("failed to get access token: %s", tokenRes.Message)
}
// Get user info
userEndpoint := common.GetEnvOrDefaultString("LINUX_DO_USER_ENDPOINT", "https://connect.linux.do/api/user")
req, err = http.NewRequest("GET", userEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+tokenRes.AccessToken)
req.Header.Set("Accept", "application/json")
res2, err := client.Do(req)
if err != nil {
return nil, errors.New("failed to get user info from Linux DO")
}
defer res2.Body.Close()
var linuxdoUser LinuxdoUser
if err := json.NewDecoder(res2.Body).Decode(&linuxdoUser); err != nil {
return nil, err
}
if linuxdoUser.Id == 0 {
return nil, errors.New("invalid user info returned")
}
return &linuxdoUser, nil
}
func LinuxdoOAuth(c *gin.Context) {
session := sessions.Default(c)
errorCode := c.Query("error")
if errorCode != "" {
errorDescription := c.Query("error_description")
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": errorDescription,
})
return
}
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
LinuxDoBind(c)
return
}
if !common.LinuxDOOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 Linux DO 登录以及注册",
})
return
}
code := c.Query("code")
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
if err != nil {
common.ApiError(c, err)
return
}
user := model.User{
LinuxDOId: strconv.Itoa(linuxdoUser.Id),
}
// Check if user exists
if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
err := user.FillUserByLinuxDOId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已注销",
})
return
}
} else {
if common.RegisterEnabled {
if linuxdoUser.TrustLevel >= common.LinuxDOMinimumTrustLevel {
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
user.DisplayName = linuxdoUser.Name
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
affCode := session.Get("aff")
inviterId := 0
if affCode != nil {
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
}
if err := user.Insert(inviterId); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "Linux DO 信任等级未达到管理员设置的最低信任等级",
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}
+1
View File
@@ -61,6 +61,7 @@ func GetStatus(c *gin.Context) {
"linuxdo_minimum_trust_level": common.LinuxDOMinimumTrustLevel,
"telegram_oauth": common.TelegramOAuthEnabled,
"telegram_bot_name": common.TelegramBotName,
"theme": system_setting.GetThemeSettings().Frontend,
"system_name": common.SystemName,
"logo": common.Logo,
"footer_html": common.Footer,
+3 -5
View File
@@ -15,9 +15,9 @@ import (
"github.com/QuantumNous/new-api/relay/channel/minimax"
"github.com/QuantumNous/new-api/relay/channel/moonshot"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
@@ -134,8 +134,7 @@ func ListModels(c *gin.Context, modelType int) {
}
for allowModel, _ := range tokenModelLimit {
if !acceptUnsetRatioModel {
_, _, exist := ratio_setting.GetModelRatioOrPrice(allowModel)
if !exist {
if !helper.HasModelBillingConfig(allowModel) {
continue
}
}
@@ -182,8 +181,7 @@ func ListModels(c *gin.Context, modelType int) {
}
for _, modelName := range models {
if !acceptUnsetRatioModel {
_, _, exist := ratio_setting.GetModelRatioOrPrice(modelName)
if !exist {
if !helper.HasModelBillingConfig(modelName) {
continue
}
}
+242
View File
@@ -0,0 +1,242 @@
package controller
import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/config"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/gin-gonic/gin"
"github.com/glebarez/sqlite"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
)
type listModelsResponse struct {
Success bool `json:"success"`
Data []dto.OpenAIModels `json:"data"`
Object string `json:"object"`
}
func setupModelListControllerTestDB(t *testing.T) *gorm.DB {
t.Helper()
initModelListColumnNames(t)
gin.SetMode(gin.TestMode)
common.UsingSQLite = true
common.UsingMySQL = false
common.UsingPostgreSQL = false
common.RedisEnabled = false
dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_"))
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
require.NoError(t, err)
model.DB = db
model.LOG_DB = db
require.NoError(t, db.AutoMigrate(&model.User{}, &model.Channel{}, &model.Ability{}, &model.Model{}, &model.Vendor{}))
t.Cleanup(func() {
sqlDB, err := db.DB()
if err == nil {
_ = sqlDB.Close()
}
})
return db
}
func initModelListColumnNames(t *testing.T) {
t.Helper()
originalIsMasterNode := common.IsMasterNode
originalSQLitePath := common.SQLitePath
originalUsingSQLite := common.UsingSQLite
originalUsingMySQL := common.UsingMySQL
originalUsingPostgreSQL := common.UsingPostgreSQL
originalSQLDSN, hadSQLDSN := os.LookupEnv("SQL_DSN")
defer func() {
common.IsMasterNode = originalIsMasterNode
common.SQLitePath = originalSQLitePath
common.UsingSQLite = originalUsingSQLite
common.UsingMySQL = originalUsingMySQL
common.UsingPostgreSQL = originalUsingPostgreSQL
if hadSQLDSN {
require.NoError(t, os.Setenv("SQL_DSN", originalSQLDSN))
} else {
require.NoError(t, os.Unsetenv("SQL_DSN"))
}
}()
common.IsMasterNode = false
common.SQLitePath = fmt.Sprintf("file:%s_init?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_"))
common.UsingSQLite = false
common.UsingMySQL = false
common.UsingPostgreSQL = false
require.NoError(t, os.Setenv("SQL_DSN", "local"))
require.NoError(t, model.InitDB())
if model.DB != nil {
sqlDB, err := model.DB.DB()
if err == nil {
_ = sqlDB.Close()
}
}
}
func withTieredBillingConfig(t *testing.T, modes map[string]string, exprs map[string]string) {
t.Helper()
saved := map[string]string{}
require.NoError(t, config.GlobalConfig.SaveToDB(func(key, value string) error {
if strings.HasPrefix(key, "billing_setting.") {
saved[key] = value
}
return nil
}))
t.Cleanup(func() {
require.NoError(t, config.GlobalConfig.LoadFromDB(saved))
model.InvalidatePricingCache()
})
modeBytes, err := common.Marshal(modes)
require.NoError(t, err)
exprBytes, err := common.Marshal(exprs)
require.NoError(t, err)
require.NoError(t, config.GlobalConfig.LoadFromDB(map[string]string{
"billing_setting.billing_mode": string(modeBytes),
"billing_setting.billing_expr": string(exprBytes),
}))
model.InvalidatePricingCache()
}
func withSelfUseModeDisabled(t *testing.T) {
t.Helper()
original := operation_setting.SelfUseModeEnabled
operation_setting.SelfUseModeEnabled = false
t.Cleanup(func() {
operation_setting.SelfUseModeEnabled = original
})
}
func decodeListModelsResponse(t *testing.T, recorder *httptest.ResponseRecorder) map[string]struct{} {
t.Helper()
require.Equal(t, http.StatusOK, recorder.Code)
var payload listModelsResponse
require.NoError(t, common.Unmarshal(recorder.Body.Bytes(), &payload))
require.True(t, payload.Success)
require.Equal(t, "list", payload.Object)
ids := make(map[string]struct{}, len(payload.Data))
for _, item := range payload.Data {
ids[item.Id] = struct{}{}
}
return ids
}
func pricingByModelName(pricings []model.Pricing) map[string]model.Pricing {
byName := make(map[string]model.Pricing, len(pricings))
for _, pricing := range pricings {
byName[pricing.ModelName] = pricing
}
return byName
}
func TestListModelsIncludesTieredBillingModel(t *testing.T) {
withSelfUseModeDisabled(t)
withTieredBillingConfig(t, map[string]string{
"zz-tiered-visible-model": "tiered_expr",
"zz-tiered-empty-expr-model": "tiered_expr",
"zz-tiered-missing-expr-model": "tiered_expr",
}, map[string]string{
"zz-tiered-visible-model": `tier("base", p * 1 + c * 2)`,
"zz-tiered-empty-expr-model": " ",
})
db := setupModelListControllerTestDB(t)
require.NoError(t, db.Create(&model.User{
Id: 1001,
Username: "model-list-user",
Password: "password",
Group: "default",
Status: common.UserStatusEnabled,
}).Error)
require.NoError(t, db.Create(&[]model.Ability{
{Group: "default", Model: "zz-tiered-visible-model", ChannelId: 1, Enabled: true},
{Group: "default", Model: "zz-tiered-empty-expr-model", ChannelId: 1, Enabled: true},
{Group: "default", Model: "zz-tiered-missing-expr-model", ChannelId: 1, Enabled: true},
{Group: "default", Model: "zz-unpriced-model", ChannelId: 1, Enabled: true},
}).Error)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
ctx.Set("id", 1001)
ListModels(ctx, constant.ChannelTypeOpenAI)
ids := decodeListModelsResponse(t, recorder)
require.Contains(t, ids, "zz-tiered-visible-model")
require.NotContains(t, ids, "zz-tiered-empty-expr-model")
require.NotContains(t, ids, "zz-tiered-missing-expr-model")
require.NotContains(t, ids, "zz-unpriced-model")
pricingByName := pricingByModelName(model.GetPricing())
visiblePricing, ok := pricingByName["zz-tiered-visible-model"]
require.True(t, ok)
require.Equal(t, "tiered_expr", visiblePricing.BillingMode)
require.NotEmpty(t, visiblePricing.BillingExpr)
emptyExprPricing, ok := pricingByName["zz-tiered-empty-expr-model"]
require.True(t, ok)
require.Empty(t, emptyExprPricing.BillingMode)
require.Empty(t, emptyExprPricing.BillingExpr)
missingExprPricing, ok := pricingByName["zz-tiered-missing-expr-model"]
require.True(t, ok)
require.Empty(t, missingExprPricing.BillingMode)
require.Empty(t, missingExprPricing.BillingExpr)
}
func TestListModelsTokenLimitIncludesTieredBillingModel(t *testing.T) {
withSelfUseModeDisabled(t)
withTieredBillingConfig(t, map[string]string{
"zz-token-tiered-visible-model": "tiered_expr",
"zz-token-tiered-empty-expr-model": "tiered_expr",
"zz-token-tiered-missing-expr-model": "tiered_expr",
}, map[string]string{
"zz-token-tiered-visible-model": `tier("base", p * 1 + c * 2)`,
"zz-token-tiered-empty-expr-model": "",
})
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
common.SetContextKey(ctx, constant.ContextKeyTokenModelLimitEnabled, true)
common.SetContextKey(ctx, constant.ContextKeyTokenModelLimit, map[string]bool{
"zz-token-tiered-visible-model": true,
"zz-token-tiered-empty-expr-model": true,
"zz-token-tiered-missing-expr-model": true,
"zz-token-unpriced-model": true,
})
ListModels(ctx, constant.ChannelTypeOpenAI)
ids := decodeListModelsResponse(t, recorder)
require.Contains(t, ids, "zz-token-tiered-visible-model")
require.NotContains(t, ids, "zz-token-tiered-empty-expr-model")
require.NotContains(t, ids, "zz-token-tiered-missing-expr-model")
require.NotContains(t, ids, "zz-token-unpriced-model")
}
+228
View File
@@ -0,0 +1,228 @@
package controller
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
type OidcResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
}
type OidcUser struct {
OpenID string `json:"sub"`
Email string `json:"email"`
Name string `json:"name"`
PreferredUsername string `json:"preferred_username"`
Picture string `json:"picture"`
}
func getOidcUserInfoByCode(code string) (*OidcUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
values := url.Values{}
values.Set("client_id", system_setting.GetOIDCSettings().ClientId)
values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret)
values.Set("code", code)
values.Set("grant_type", "authorization_code")
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", system_setting.ServerAddress))
formData := values.Encode()
req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
}
defer res.Body.Close()
var oidcResponse OidcResponse
err = json.NewDecoder(res.Body).Decode(&oidcResponse)
if err != nil {
return nil, err
}
if oidcResponse.AccessToken == "" {
common.SysLog("OIDC 获取 Token 失败,请检查设置!")
return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
}
req, err = http.NewRequest("GET", system_setting.GetOIDCSettings().UserInfoEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
res2, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
}
defer res2.Body.Close()
if res2.StatusCode != http.StatusOK {
common.SysLog("OIDC 获取用户信息失败!请检查设置!")
return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
}
var oidcUser OidcUser
err = json.NewDecoder(res2.Body).Decode(&oidcUser)
if err != nil {
return nil, err
}
if oidcUser.OpenID == "" || oidcUser.Email == "" {
common.SysLog("OIDC 获取用户信息为空!请检查设置!")
return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
}
return &oidcUser, nil
}
func OidcAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
OidcBind(c)
return
}
if !system_setting.GetOIDCSettings().Enabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 OIDC 登录以及注册",
})
return
}
code := c.Query("code")
oidcUser, err := getOidcUserInfoByCode(code)
if err != nil {
common.ApiError(c, err)
return
}
user := model.User{
OidcId: oidcUser.OpenID,
}
if model.IsOidcIdAlreadyTaken(user.OidcId) {
err := user.FillUserByOidcId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if common.RegisterEnabled {
user.Email = oidcUser.Email
if oidcUser.PreferredUsername != "" {
user.Username = oidcUser.PreferredUsername
} else {
user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
}
if oidcUser.Name != "" {
user.DisplayName = oidcUser.Name
} else {
user.DisplayName = "OIDC User"
}
err := user.Insert(0)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}
func OidcBind(c *gin.Context) {
if !system_setting.GetOIDCSettings().Enabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 OIDC 登录以及注册",
})
return
}
code := c.Query("code")
oidcUser, err := getOidcUserInfoByCode(code)
if err != nil {
common.ApiError(c, err)
return
}
user := model.User{
OidcId: oidcUser.OpenID,
}
if model.IsOidcIdAlreadyTaken(user.OidcId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 OIDC 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
common.ApiError(c, err)
return
}
user.OidcId = oidcUser.OpenID
err = user.Update(false)
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}
+20 -2
View File
@@ -27,6 +27,15 @@ var completionRatioMetaOptionKeys = []string{
"AudioCompletionRatio",
}
func isVisiblePublicKeyOption(key string) bool {
switch key {
case "WaffoPancakeWebhookPublicKey", "WaffoPancakeWebhookTestKey":
return true
default:
return false
}
}
func collectModelNamesFromOptionValue(raw string, modelNames map[string]struct{}) {
if strings.TrimSpace(raw) == "" {
return
@@ -66,11 +75,12 @@ func GetOptions(c *gin.Context) {
common.OptionMapRWMutex.Lock()
for k, v := range common.OptionMap {
value := common.Interface2String(v)
if strings.HasSuffix(k, "Token") ||
isSensitiveKey := strings.HasSuffix(k, "Token") ||
strings.HasSuffix(k, "Secret") ||
strings.HasSuffix(k, "Key") ||
strings.HasSuffix(k, "secret") ||
strings.HasSuffix(k, "api_key") {
strings.HasSuffix(k, "api_key")
if isSensitiveKey && !isVisiblePublicKeyOption(k) {
continue
}
options = append(options, &model.Option{
@@ -188,6 +198,14 @@ func UpdateOption(c *gin.Context) {
})
return
}
case "theme.frontend":
if option.Value != "default" && option.Value != "classic" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的主题值,可选值:default(新版前端)、classic(经典前端)",
})
return
}
case "GroupRatio":
err = ratio_setting.CheckGroupRatio(option.Value.(string))
if err != nil {
+70
View File
@@ -36,6 +36,10 @@ func PasskeyRegisterBegin(c *gin.Context) {
return
}
if !requirePasskeyRegistrationVerification(c, user.Id) {
return
}
credential, err := model.GetPasskeyByUserID(user.Id)
if err != nil && !errors.Is(err, model.ErrPasskeyNotFound) {
common.ApiError(c, err)
@@ -96,6 +100,10 @@ func PasskeyRegisterFinish(c *gin.Context) {
return
}
if !requirePasskeyRegistrationVerification(c, user.Id) {
return
}
wa, err := passkeysvc.BuildWebAuthn(c.Request)
if err != nil {
common.ApiError(c, err)
@@ -151,6 +159,10 @@ func PasskeyDelete(c *gin.Context) {
return
}
if !requirePasskeyDeleteVerification(c, user.Id) {
return
}
if err := model.DeletePasskeyByUserID(user.Id); err != nil {
common.ApiError(c, err)
return
@@ -474,6 +486,7 @@ func PasskeyVerifyFinish(c *gin.Context) {
// Mark passkey as ready; /api/verify will convert this into the final secure verification session.
session.Set(PasskeyReadySessionKey, time.Now().Unix())
session.Delete(SecureVerificationSessionKey)
session.Delete(secureVerificationMethodSessionKey)
if err := session.Save(); err != nil {
common.ApiError(c, fmt.Errorf("保存验证状态失败: %v", err))
return
@@ -504,3 +517,60 @@ func getSessionUser(c *gin.Context) (*model.User, error) {
}
return user, nil
}
func requirePasskeyRegistrationVerification(c *gin.Context, userID int) bool {
twoFA, err := model.GetTwoFAByUserId(userID)
if err != nil {
common.ApiError(c, err)
return false
}
if twoFA == nil || !twoFA.IsEnabled {
return true
}
return requireSecureVerificationMethod(c, secureVerificationMethod2FA)
}
func requirePasskeyDeleteVerification(c *gin.Context, userID int) bool {
twoFA, err := model.GetTwoFAByUserId(userID)
if err != nil {
common.ApiError(c, err)
return false
}
if twoFA != nil && twoFA.IsEnabled {
return requireSecureVerificationMethod(c, secureVerificationMethod2FA)
}
_, err = model.GetPasskeyByUserID(userID)
if err != nil {
if errors.Is(err, model.ErrPasskeyNotFound) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该用户尚未绑定 Passkey",
})
return false
}
common.ApiError(c, err)
return false
}
return requireSecureVerificationMethod(c, secureVerificationMethodPasskey)
}
func requireSecureVerificationMethod(c *gin.Context, method string) bool {
session := sessions.Default(c)
verifiedAt, ok := session.Get(SecureVerificationSessionKey).(int64)
if !ok || time.Now().Unix()-verifiedAt >= SecureVerificationTimeout {
session.Delete(SecureVerificationSessionKey)
session.Delete(secureVerificationMethodSessionKey)
_ = session.Save()
common.ApiErrorMsg(c, "请先完成安全验证")
return false
}
if verifiedMethod, ok := session.Get(secureVerificationMethodSessionKey).(string); !ok || verifiedMethod != method {
common.ApiErrorMsg(c, "请先完成对应的安全验证")
return false
}
return true
}
+100
View File
@@ -0,0 +1,100 @@
package controller
import (
"strings"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/operation_setting"
)
func isStripeTopUpEnabled() bool {
return strings.TrimSpace(setting.StripeApiSecret) != "" &&
strings.TrimSpace(setting.StripeWebhookSecret) != "" &&
strings.TrimSpace(setting.StripePriceId) != ""
}
func isStripeWebhookConfigured() bool {
return strings.TrimSpace(setting.StripeWebhookSecret) != ""
}
func isStripeWebhookEnabled() bool {
return isStripeTopUpEnabled()
}
func isCreemTopUpEnabled() bool {
products := strings.TrimSpace(setting.CreemProducts)
return strings.TrimSpace(setting.CreemApiKey) != "" &&
products != "" &&
products != "[]"
}
func isCreemWebhookConfigured() bool {
return strings.TrimSpace(setting.CreemWebhookSecret) != ""
}
func isCreemWebhookEnabled() bool {
return isCreemTopUpEnabled() && isCreemWebhookConfigured()
}
func isWaffoTopUpEnabled() bool {
if !setting.WaffoEnabled {
return false
}
return isWaffoWebhookConfigured()
}
func isWaffoWebhookConfigured() bool {
if setting.WaffoSandbox {
return strings.TrimSpace(setting.WaffoSandboxApiKey) != "" &&
strings.TrimSpace(setting.WaffoSandboxPrivateKey) != "" &&
strings.TrimSpace(setting.WaffoSandboxPublicCert) != ""
}
return strings.TrimSpace(setting.WaffoApiKey) != "" &&
strings.TrimSpace(setting.WaffoPrivateKey) != "" &&
strings.TrimSpace(setting.WaffoPublicCert) != ""
}
func isWaffoWebhookEnabled() bool {
return isWaffoTopUpEnabled()
}
func isWaffoPancakeTopUpEnabled() bool {
if !setting.WaffoPancakeEnabled {
return false
}
return isWaffoPancakeWebhookConfigured() &&
strings.TrimSpace(setting.WaffoPancakeMerchantID) != "" &&
strings.TrimSpace(setting.WaffoPancakePrivateKey) != "" &&
strings.TrimSpace(setting.WaffoPancakeStoreID) != "" &&
strings.TrimSpace(setting.WaffoPancakeProductID) != ""
}
func isWaffoPancakeWebhookConfigured() bool {
currentWebhookKey := strings.TrimSpace(setting.WaffoPancakeWebhookPublicKey)
if setting.WaffoPancakeSandbox {
currentWebhookKey = strings.TrimSpace(setting.WaffoPancakeWebhookTestKey)
}
return currentWebhookKey != ""
}
func isWaffoPancakeWebhookEnabled() bool {
return isWaffoPancakeTopUpEnabled()
}
func isEpayTopUpEnabled() bool {
return isEpayWebhookConfigured() && len(operation_setting.PayMethods) > 0
}
func isEpayWebhookConfigured() bool {
return strings.TrimSpace(operation_setting.PayAddress) != "" &&
strings.TrimSpace(operation_setting.EpayId) != "" &&
strings.TrimSpace(operation_setting.EpayKey) != ""
}
func isEpayWebhookEnabled() bool {
return isEpayTopUpEnabled()
}
@@ -0,0 +1,166 @@
package controller
import (
"testing"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/stretchr/testify/require"
)
func TestStripeWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
originalAPISecret := setting.StripeApiSecret
originalWebhookSecret := setting.StripeWebhookSecret
originalPriceID := setting.StripePriceId
t.Cleanup(func() {
setting.StripeApiSecret = originalAPISecret
setting.StripeWebhookSecret = originalWebhookSecret
setting.StripePriceId = originalPriceID
})
setting.StripeWebhookSecret = ""
setting.StripeApiSecret = "sk_test_123"
setting.StripePriceId = "price_123"
require.False(t, isStripeWebhookEnabled())
setting.StripeWebhookSecret = "whsec_test"
require.True(t, isStripeWebhookEnabled())
setting.StripePriceId = ""
require.False(t, isStripeWebhookEnabled())
}
func TestCreemWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
originalAPIKey := setting.CreemApiKey
originalProducts := setting.CreemProducts
originalWebhookSecret := setting.CreemWebhookSecret
t.Cleanup(func() {
setting.CreemApiKey = originalAPIKey
setting.CreemProducts = originalProducts
setting.CreemWebhookSecret = originalWebhookSecret
})
setting.CreemWebhookSecret = ""
setting.CreemApiKey = "creem_api_key"
setting.CreemProducts = `[{"productId":"prod_123"}]`
require.False(t, isCreemWebhookEnabled())
setting.CreemWebhookSecret = "creem_secret"
require.True(t, isCreemWebhookEnabled())
setting.CreemProducts = "[]"
require.False(t, isCreemWebhookEnabled())
}
func TestWaffoWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
originalEnabled := setting.WaffoEnabled
originalSandbox := setting.WaffoSandbox
originalAPIKey := setting.WaffoApiKey
originalPrivateKey := setting.WaffoPrivateKey
originalPublicCert := setting.WaffoPublicCert
originalSandboxAPIKey := setting.WaffoSandboxApiKey
originalSandboxPrivateKey := setting.WaffoSandboxPrivateKey
originalSandboxPublicCert := setting.WaffoSandboxPublicCert
t.Cleanup(func() {
setting.WaffoEnabled = originalEnabled
setting.WaffoSandbox = originalSandbox
setting.WaffoApiKey = originalAPIKey
setting.WaffoPrivateKey = originalPrivateKey
setting.WaffoPublicCert = originalPublicCert
setting.WaffoSandboxApiKey = originalSandboxAPIKey
setting.WaffoSandboxPrivateKey = originalSandboxPrivateKey
setting.WaffoSandboxPublicCert = originalSandboxPublicCert
})
setting.WaffoEnabled = true
setting.WaffoSandbox = false
setting.WaffoApiKey = ""
setting.WaffoPrivateKey = "private"
setting.WaffoPublicCert = "public"
require.False(t, isWaffoWebhookEnabled())
setting.WaffoApiKey = "api"
require.True(t, isWaffoWebhookEnabled())
setting.WaffoEnabled = false
require.False(t, isWaffoWebhookEnabled())
setting.WaffoEnabled = true
setting.WaffoSandbox = true
setting.WaffoSandboxApiKey = ""
setting.WaffoSandboxPrivateKey = "sandbox_private"
setting.WaffoSandboxPublicCert = "sandbox_public"
require.False(t, isWaffoWebhookEnabled())
setting.WaffoSandboxApiKey = "sandbox_api"
require.True(t, isWaffoWebhookEnabled())
}
func TestWaffoPancakeWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
originalEnabled := setting.WaffoPancakeEnabled
originalSandbox := setting.WaffoPancakeSandbox
originalMerchantID := setting.WaffoPancakeMerchantID
originalPrivateKey := setting.WaffoPancakePrivateKey
originalWebhookPublicKey := setting.WaffoPancakeWebhookPublicKey
originalWebhookTestKey := setting.WaffoPancakeWebhookTestKey
originalStoreID := setting.WaffoPancakeStoreID
originalProductID := setting.WaffoPancakeProductID
t.Cleanup(func() {
setting.WaffoPancakeEnabled = originalEnabled
setting.WaffoPancakeSandbox = originalSandbox
setting.WaffoPancakeMerchantID = originalMerchantID
setting.WaffoPancakePrivateKey = originalPrivateKey
setting.WaffoPancakeWebhookPublicKey = originalWebhookPublicKey
setting.WaffoPancakeWebhookTestKey = originalWebhookTestKey
setting.WaffoPancakeStoreID = originalStoreID
setting.WaffoPancakeProductID = originalProductID
})
setting.WaffoPancakeEnabled = true
setting.WaffoPancakeSandbox = false
setting.WaffoPancakeMerchantID = "merchant"
setting.WaffoPancakePrivateKey = "private"
setting.WaffoPancakeStoreID = "store"
setting.WaffoPancakeProductID = "product"
setting.WaffoPancakeWebhookPublicKey = ""
require.False(t, isWaffoPancakeWebhookEnabled())
setting.WaffoPancakeWebhookPublicKey = "public"
require.True(t, isWaffoPancakeWebhookEnabled())
setting.WaffoPancakeEnabled = false
require.False(t, isWaffoPancakeWebhookEnabled())
setting.WaffoPancakeEnabled = true
setting.WaffoPancakeSandbox = true
setting.WaffoPancakeWebhookTestKey = ""
require.False(t, isWaffoPancakeWebhookEnabled())
setting.WaffoPancakeWebhookTestKey = "test_public"
require.True(t, isWaffoPancakeWebhookEnabled())
}
func TestEpayWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
originalPayAddress := operation_setting.PayAddress
originalEpayID := operation_setting.EpayId
originalEpayKey := operation_setting.EpayKey
originalPayMethods := operation_setting.PayMethods
t.Cleanup(func() {
operation_setting.PayAddress = originalPayAddress
operation_setting.EpayId = originalEpayID
operation_setting.EpayKey = originalEpayKey
operation_setting.PayMethods = originalPayMethods
})
operation_setting.PayAddress = "https://pay.example.com"
operation_setting.EpayId = "epay_id"
operation_setting.EpayKey = ""
operation_setting.PayMethods = []map[string]string{{"type": "alipay"}}
require.False(t, isEpayWebhookEnabled())
operation_setting.EpayKey = "epay_key"
require.True(t, isEpayWebhookEnabled())
operation_setting.PayMethods = nil
require.False(t, isEpayWebhookEnabled())
}
+161 -46
View File
@@ -21,14 +21,16 @@ import (
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/billing_setting"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/samber/lo"
"github.com/gin-gonic/gin"
)
const (
defaultTimeoutSeconds = 10
defaultEndpoint = "/api/ratio_config"
defaultEndpoint = "/api/pricing"
maxConcurrentFetches = 8
maxRatioConfigBytes = 10 << 20 // 10MB
floatEpsilon = 1e-9
@@ -59,7 +61,29 @@ func valuesEqual(a, b interface{}) bool {
return a == b
}
var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
var pricingSyncFields = []string{
"model_ratio",
"completion_ratio",
"cache_ratio",
"create_cache_ratio",
"image_ratio",
"audio_ratio",
"audio_completion_ratio",
"model_price",
billing_setting.BillingModeField,
billing_setting.BillingExprField,
}
var numericPricingSyncFields = map[string]bool{
"model_ratio": true,
"completion_ratio": true,
"cache_ratio": true,
"create_cache_ratio": true,
"image_ratio": true,
"audio_ratio": true,
"audio_completion_ratio": true,
"model_price": true,
}
type upstreamResult struct {
Name string `json:"name"`
@@ -67,6 +91,54 @@ type upstreamResult struct {
Err string `json:"err,omitempty"`
}
func valueMap(value any) map[string]any {
switch typed := value.(type) {
case map[string]any:
return typed
case map[string]float64:
return lo.MapValues(typed, func(value float64, _ string) any { return value })
case map[string]string:
return lo.MapValues(typed, func(value string, _ string) any { return value })
default:
return nil
}
}
func asFloat64(value any) (float64, bool) {
switch typed := value.(type) {
case float64:
return typed, true
case float32:
return float64(typed), true
case int:
return float64(typed), true
case int64:
return float64(typed), true
case json.Number:
parsed, err := typed.Float64()
return parsed, err == nil
default:
return 0, false
}
}
func normalizeSyncValue(field string, value any) any {
if numericPricingSyncFields[field] {
if parsed, ok := asFloat64(value); ok {
return parsed
}
}
return value
}
func getLocalPricingSyncData() map[string]any {
data := billing_setting.GetPricingSyncData(map[string]any(ratio_setting.GetExposedData()))
data["image_ratio"] = ratio_setting.GetImageRatioCopy()
data["audio_ratio"] = ratio_setting.GetAudioRatioCopy()
data["audio_completion_ratio"] = ratio_setting.GetAudioCompletionRatioCopy()
return data
}
func FetchUpstreamRatios(c *gin.Context) {
var req dto.UpstreamRequest
if err := c.ShouldBindJSON(&req); err != nil {
@@ -293,7 +365,7 @@ func FetchUpstreamRatios(c *gin.Context) {
if err := common.Unmarshal(body.Data, &type1Data); err == nil {
// 如果包含至少一个 ratioTypes 字段,则认为是 type1
isType1 := false
for _, rt := range ratioTypes {
for _, rt := range pricingSyncFields {
if _, ok := type1Data[rt]; ok {
isType1 = true
break
@@ -307,11 +379,18 @@ func FetchUpstreamRatios(c *gin.Context) {
// 如果不是 type1,则尝试按 type2 (/api/pricing) 解析
var pricingItems []struct {
ModelName string `json:"model_name"`
QuotaType int `json:"quota_type"`
ModelRatio float64 `json:"model_ratio"`
ModelPrice float64 `json:"model_price"`
CompletionRatio float64 `json:"completion_ratio"`
ModelName string `json:"model_name"`
QuotaType int `json:"quota_type"`
ModelRatio float64 `json:"model_ratio"`
ModelPrice float64 `json:"model_price"`
CompletionRatio float64 `json:"completion_ratio"`
CacheRatio *float64 `json:"cache_ratio"`
CreateCacheRatio *float64 `json:"create_cache_ratio"`
ImageRatio *float64 `json:"image_ratio"`
AudioRatio *float64 `json:"audio_ratio"`
AudioCompletionRatio *float64 `json:"audio_completion_ratio"`
BillingMode string `json:"billing_mode"`
BillingExpr string `json:"billing_expr"`
}
if err := common.Unmarshal(body.Data, &pricingItems); err != nil {
logger.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
@@ -321,9 +400,23 @@ func FetchUpstreamRatios(c *gin.Context) {
modelRatioMap := make(map[string]float64)
completionRatioMap := make(map[string]float64)
cacheRatioMap := make(map[string]float64)
createCacheRatioMap := make(map[string]float64)
imageRatioMap := make(map[string]float64)
audioRatioMap := make(map[string]float64)
audioCompletionRatioMap := make(map[string]float64)
modelPriceMap := make(map[string]float64)
billingModeMap := make(map[string]string)
billingExprMap := make(map[string]string)
for _, item := range pricingItems {
if item.ModelName == "" {
continue
}
if item.BillingMode == billing_setting.BillingModeTieredExpr && strings.TrimSpace(item.BillingExpr) != "" {
billingModeMap[item.ModelName] = billing_setting.BillingModeTieredExpr
billingExprMap[item.ModelName] = item.BillingExpr
}
if item.QuotaType == 1 {
modelPriceMap[item.ModelName] = item.ModelPrice
} else {
@@ -331,6 +424,21 @@ func FetchUpstreamRatios(c *gin.Context) {
// completionRatio 可能为 0,此时也直接赋值,保持与上游一致
completionRatioMap[item.ModelName] = item.CompletionRatio
}
if item.CacheRatio != nil {
cacheRatioMap[item.ModelName] = *item.CacheRatio
}
if item.CreateCacheRatio != nil {
createCacheRatioMap[item.ModelName] = *item.CreateCacheRatio
}
if item.ImageRatio != nil {
imageRatioMap[item.ModelName] = *item.ImageRatio
}
if item.AudioRatio != nil {
audioRatioMap[item.ModelName] = *item.AudioRatio
}
if item.AudioCompletionRatio != nil {
audioCompletionRatioMap[item.ModelName] = *item.AudioCompletionRatio
}
}
converted := make(map[string]any)
@@ -350,6 +458,21 @@ func FetchUpstreamRatios(c *gin.Context) {
}
converted["completion_ratio"] = compAny
}
if len(cacheRatioMap) > 0 {
converted["cache_ratio"] = valueMap(cacheRatioMap)
}
if len(createCacheRatioMap) > 0 {
converted["create_cache_ratio"] = valueMap(createCacheRatioMap)
}
if len(imageRatioMap) > 0 {
converted["image_ratio"] = valueMap(imageRatioMap)
}
if len(audioRatioMap) > 0 {
converted["audio_ratio"] = valueMap(audioRatioMap)
}
if len(audioCompletionRatioMap) > 0 {
converted["audio_completion_ratio"] = valueMap(audioCompletionRatioMap)
}
if len(modelPriceMap) > 0 {
priceAny := make(map[string]any, len(modelPriceMap))
@@ -358,6 +481,12 @@ func FetchUpstreamRatios(c *gin.Context) {
}
converted["model_price"] = priceAny
}
if len(billingModeMap) > 0 {
converted[billing_setting.BillingModeField] = valueMap(billingModeMap)
}
if len(billingExprMap) > 0 {
converted[billing_setting.BillingExprField] = valueMap(billingExprMap)
}
ch <- upstreamResult{Name: uniqueName, Data: converted}
}(chn)
@@ -366,7 +495,7 @@ func FetchUpstreamRatios(c *gin.Context) {
wg.Wait()
close(ch)
localData := ratio_setting.GetExposedData()
localData := getLocalPricingSyncData()
var testResults []dto.TestResult
var successfulChannels []struct {
@@ -412,22 +541,16 @@ func buildDifferences(localData map[string]any, successfulChannels []struct {
allModels := make(map[string]struct{})
for _, ratioType := range ratioTypes {
if localRatioAny, ok := localData[ratioType]; ok {
if localRatio, ok := localRatioAny.(map[string]float64); ok {
for modelName := range localRatio {
allModels[modelName] = struct{}{}
}
}
for _, field := range pricingSyncFields {
for modelName := range valueMap(localData[field]) {
allModels[modelName] = struct{}{}
}
}
for _, channel := range successfulChannels {
for _, ratioType := range ratioTypes {
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
for modelName := range upstreamRatio {
allModels[modelName] = struct{}{}
}
for _, field := range pricingSyncFields {
for modelName := range valueMap(channel.data[field]) {
allModels[modelName] = struct{}{}
}
}
}
@@ -438,10 +561,10 @@ func buildDifferences(localData map[string]any, successfulChannels []struct {
for _, channel := range successfulChannels {
confidenceMap[channel.name] = make(map[string]bool)
modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
modelRatios := valueMap(channel.data["model_ratio"])
completionRatios := valueMap(channel.data["completion_ratio"])
if hasModelRatio && hasCompletionRatio {
if len(modelRatios) > 0 && len(completionRatios) > 0 {
// 遍历所有模型,检查是否满足不可信条件
for modelName := range allModels {
// 默认为可信
@@ -451,12 +574,10 @@ func buildDifferences(localData map[string]any, successfulChannels []struct {
if modelRatioVal, ok := modelRatios[modelName]; ok {
if completionRatioVal, ok := completionRatios[modelName]; ok {
// 转换为float64进行比较
if modelRatioFloat, ok := modelRatioVal.(float64); ok {
if completionRatioFloat, ok := completionRatioVal.(float64); ok {
if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
confidenceMap[channel.name][modelName] = false
}
}
modelRatioFloat, modelRatioOK := asFloat64(modelRatioVal)
completionRatioFloat, completionRatioOK := asFloat64(completionRatioVal)
if modelRatioOK && completionRatioOK && nearlyEqual(modelRatioFloat, 37.5) && nearlyEqual(completionRatioFloat, 1.0) {
confidenceMap[channel.name][modelName] = false
}
}
}
@@ -470,14 +591,10 @@ func buildDifferences(localData map[string]any, successfulChannels []struct {
}
for modelName := range allModels {
for _, ratioType := range ratioTypes {
for _, ratioType := range pricingSyncFields {
var localValue interface{} = nil
if localRatioAny, ok := localData[ratioType]; ok {
if localRatio, ok := localRatioAny.(map[string]float64); ok {
if val, exists := localRatio[modelName]; exists {
localValue = val
}
}
if val, exists := valueMap(localData[ratioType])[modelName]; exists {
localValue = normalizeSyncValue(ratioType, val)
}
upstreamValues := make(map[string]interface{})
@@ -488,16 +605,14 @@ func buildDifferences(localData map[string]any, successfulChannels []struct {
for _, channel := range successfulChannels {
var upstreamValue interface{} = nil
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
if val, exists := upstreamRatio[modelName]; exists {
upstreamValue = val
hasUpstreamValue = true
if val, exists := valueMap(channel.data[ratioType])[modelName]; exists {
upstreamValue = normalizeSyncValue(ratioType, val)
hasUpstreamValue = true
if localValue != nil && !valuesEqual(localValue, val) {
hasDifference = true
} else if valuesEqual(localValue, val) {
upstreamValue = "same"
}
if localValue != nil && !valuesEqual(localValue, upstreamValue) {
hasDifference = true
} else if valuesEqual(localValue, upstreamValue) {
upstreamValue = "same"
}
}
if upstreamValue == nil && localValue == nil {
+7 -3
View File
@@ -13,7 +13,10 @@ import (
const (
// SecureVerificationSessionKey means the user has fully passed secure verification.
SecureVerificationSessionKey = "secure_verified_at"
SecureVerificationSessionKey = "secure_verified_at"
secureVerificationMethodSessionKey = "secure_verified_method"
secureVerificationMethod2FA = "2fa"
secureVerificationMethodPasskey = "passkey"
// PasskeyReadySessionKey means WebAuthn finished and /api/verify can finalize step-up verification.
PasskeyReadySessionKey = "secure_passkey_ready_at"
// SecureVerificationTimeout 验证有效期(秒)
@@ -120,7 +123,7 @@ func UniversalVerify(c *gin.Context) {
}
// 验证成功,在 session 中记录时间戳
now, err := setSecureVerificationSession(c)
now, err := setSecureVerificationSession(c, req.Method)
if err != nil {
common.ApiError(c, fmt.Errorf("保存验证状态失败: %v", err))
return
@@ -139,11 +142,12 @@ func UniversalVerify(c *gin.Context) {
})
}
func setSecureVerificationSession(c *gin.Context) (int64, error) {
func setSecureVerificationSession(c *gin.Context, method string) (int64, error) {
session := sessions.Default(c)
session.Delete(PasskeyReadySessionKey)
now := time.Now().Unix()
session.Set(SecureVerificationSessionKey, now)
session.Set(secureVerificationMethodSessionKey, method)
if err := session.Save(); err != nil {
return 0, err
}
+19 -16
View File
@@ -2,11 +2,13 @@ package controller
import (
"bytes"
"fmt"
"io"
"log"
"net/http"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/operation_setting"
@@ -24,14 +26,14 @@ func SubscriptionRequestCreemPay(c *gin.Context) {
// Keep body for debugging consistency (like RequestCreemPay)
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
log.Printf("read subscription creem pay req body err: %v", err)
c.JSON(200, gin.H{"message": "error", "data": "read query error"})
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 订阅支付请求读取失败 error=%q", err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "read query error"})
return
}
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
if err := c.ShouldBindJSON(&req); err != nil || req.PlanId <= 0 {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
@@ -81,16 +83,17 @@ func SubscriptionRequestCreemPay(c *gin.Context) {
// create pending order first
order := &model.SubscriptionOrder{
UserId: userId,
PlanId: plan.Id,
Money: plan.PriceAmount,
TradeNo: referenceId,
PaymentMethod: PaymentMethodCreem,
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
UserId: userId,
PlanId: plan.Id,
Money: plan.PriceAmount,
TradeNo: referenceId,
PaymentMethod: model.PaymentMethodCreem,
PaymentProvider: model.PaymentProviderCreem,
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
}
if err := order.Insert(); err != nil {
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
return
}
@@ -112,14 +115,14 @@ func SubscriptionRequestCreemPay(c *gin.Context) {
Quota: 0,
}
checkoutUrl, err := genCreemLink(referenceId, product, user.Email, user.Username)
checkoutUrl, err := genCreemLink(c.Request.Context(), referenceId, product, user.Email, user.Username)
if err != nil {
log.Printf("获取Creem支付链接失败: %v", err)
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 订阅支付链接创建失败 trade_no=%s product_id=%s error=%q", referenceId, product.ProductId, err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
return
}
c.JSON(200, gin.H{
c.JSON(http.StatusOK, gin.H{
"message": "success",
"data": gin.H{
"checkout_url": checkoutUrl,
+11 -10
View File
@@ -82,13 +82,14 @@ func SubscriptionRequestEpay(c *gin.Context) {
}
order := &model.SubscriptionOrder{
UserId: userId,
PlanId: plan.Id,
Money: plan.PriceAmount,
TradeNo: tradeNo,
PaymentMethod: req.PaymentMethod,
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
UserId: userId,
PlanId: plan.Id,
Money: plan.PriceAmount,
TradeNo: tradeNo,
PaymentMethod: req.PaymentMethod,
PaymentProvider: model.PaymentProviderEpay,
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
}
if err := order.Insert(); err != nil {
common.ApiErrorMsg(c, "创建订单失败")
@@ -104,7 +105,7 @@ func SubscriptionRequestEpay(c *gin.Context) {
ReturnUrl: returnUrl,
})
if err != nil {
_ = model.ExpireSubscriptionOrder(tradeNo)
_ = model.ExpireSubscriptionOrder(tradeNo, model.PaymentProviderEpay)
common.ApiErrorMsg(c, "拉起支付失败")
return
}
@@ -156,7 +157,7 @@ func SubscriptionEpayNotify(c *gin.Context) {
LockOrder(verifyInfo.ServiceTradeNo)
defer UnlockOrder(verifyInfo.ServiceTradeNo)
if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo)); err != nil {
if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), model.PaymentProviderEpay, verifyInfo.Type); err != nil {
_, _ = c.Writer.Write([]byte("fail"))
return
}
@@ -205,7 +206,7 @@ func SubscriptionEpayReturn(c *gin.Context) {
if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
LockOrder(verifyInfo.ServiceTradeNo)
defer UnlockOrder(verifyInfo.ServiceTradeNo)
if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo)); err != nil {
if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), model.PaymentProviderEpay, verifyInfo.Type); err != nil {
c.Redirect(http.StatusFound, system_setting.ServerAddress+"/console/topup?pay=fail")
return
}
+10 -9
View File
@@ -2,12 +2,12 @@ package controller
import (
"fmt"
"log"
"net/http"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/system_setting"
@@ -78,19 +78,20 @@ func SubscriptionRequestStripePay(c *gin.Context) {
payLink, err := genStripeSubscriptionLink(referenceId, user.StripeCustomer, user.Email, plan.StripePriceId)
if err != nil {
log.Println("获取Stripe Checkout支付链接失败", err)
logger.LogError(c.Request.Context(), fmt.Sprintf("Stripe 订阅支付链接创建失败 trade_no=%s plan_id=%d error=%q", referenceId, plan.Id, err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
return
}
order := &model.SubscriptionOrder{
UserId: userId,
PlanId: plan.Id,
Money: plan.PriceAmount,
TradeNo: referenceId,
PaymentMethod: PaymentMethodStripe,
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
UserId: userId,
PlanId: plan.Id,
Money: plan.PriceAmount,
TradeNo: referenceId,
PaymentMethod: model.PaymentMethodStripe,
PaymentProvider: model.PaymentProviderStripe,
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
}
if err := order.Insert(); err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
+313
View File
@@ -0,0 +1,313 @@
package controller
import (
"context"
"encoding/json"
"fmt"
"io"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay"
"github.com/QuantumNous/new-api/relay/channel"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/setting/ratio_setting"
)
func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
for channelId, taskIds := range taskChannelM {
if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
}
}
return nil
}
func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
if len(taskIds) == 0 {
return nil
}
cacheGetChannel, err := model.CacheGetChannel(channelId)
if err != nil {
errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
"status": "FAILURE",
"progress": "100%",
})
if errUpdate != nil {
common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
}
return fmt.Errorf("CacheGetChannel failed: %w", err)
}
adaptor := relay.GetTaskAdaptor(platform)
if adaptor == nil {
return fmt.Errorf("video adaptor not found")
}
info := &relaycommon.RelayInfo{}
info.ChannelMeta = &relaycommon.ChannelMeta{
ChannelBaseUrl: cacheGetChannel.GetBaseURL(),
}
info.ApiKey = cacheGetChannel.Key
adaptor.Init(info)
for _, taskId := range taskIds {
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
}
}
return nil
}
func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
baseURL := constant.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
proxy := channel.GetSetting().Proxy
task := taskM[taskId]
if task == nil {
logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
return fmt.Errorf("task %s not found", taskId)
}
key := channel.Key
privateData := task.PrivateData
if privateData.Key != "" {
key = privateData.Key
}
resp, err := adaptor.FetchTask(baseURL, key, map[string]any{
"task_id": taskId,
"action": task.Action,
}, proxy)
if err != nil {
return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
}
//if resp.StatusCode != http.StatusOK {
//return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
//}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
}
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask response: %s", string(responseBody)))
taskResult := &relaycommon.TaskInfo{}
// try parse as New API response format
var responseItems dto.TaskResponse[model.Task]
if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask parsed as new api response format: %+v", responseItems))
t := responseItems.Data
taskResult.TaskID = t.TaskID
taskResult.Status = string(t.Status)
taskResult.Url = t.FailReason
taskResult.Progress = t.Progress
taskResult.Reason = t.FailReason
task.Data = t.Data
} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
} else {
task.Data = redactVideoResponseBody(responseBody)
}
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask taskResult: %+v", taskResult))
now := time.Now().Unix()
if taskResult.Status == "" {
//return fmt.Errorf("task %s status is empty", taskId)
taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
}
// 记录原本的状态,防止重复退款
shouldRefund := false
quota := task.Quota
preStatus := task.Status
task.Status = model.TaskStatus(taskResult.Status)
switch taskResult.Status {
case model.TaskStatusSubmitted:
task.Progress = "10%"
case model.TaskStatusQueued:
task.Progress = "20%"
case model.TaskStatusInProgress:
task.Progress = "30%"
if task.StartTime == 0 {
task.StartTime = now
}
case model.TaskStatusSuccess:
task.Progress = "100%"
if task.FinishTime == 0 {
task.FinishTime = now
}
if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") {
task.FailReason = taskResult.Url
}
// 如果返回了 total_tokens 并且配置了模型倍率(非固定价格),则重新计费
if taskResult.TotalTokens > 0 {
// 获取模型名称
var taskData map[string]interface{}
if err := json.Unmarshal(task.Data, &taskData); err == nil {
if modelName, ok := taskData["model"].(string); ok && modelName != "" {
// 获取模型价格和倍率
modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
// 只有配置了倍率(非固定价格)时才按 token 重新计费
if hasRatioSetting && modelRatio > 0 {
// 获取用户和组的倍率信息
group := task.Group
if group == "" {
user, err := model.GetUserById(task.UserId, false)
if err == nil {
group = user.Group
}
}
if group != "" {
groupRatio := ratio_setting.GetGroupRatio(group)
userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group)
var finalGroupRatio float64
if hasUserGroupRatio {
finalGroupRatio = userGroupRatio
} else {
finalGroupRatio = groupRatio
}
// 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
actualQuota := int(float64(taskResult.TotalTokens) * modelRatio * finalGroupRatio)
// 计算差额
preConsumedQuota := task.Quota
quotaDelta := actualQuota - preConsumedQuota
if quotaDelta > 0 {
// 需要补扣费
logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后补扣费:%s(实际消耗:%s,预扣费:%stokens%d",
task.TaskID,
logger.LogQuota(quotaDelta),
logger.LogQuota(actualQuota),
logger.LogQuota(preConsumedQuota),
taskResult.TotalTokens,
))
if err := model.DecreaseUserQuota(task.UserId, quotaDelta, false); err != nil {
logger.LogError(ctx, fmt.Sprintf("补扣费失败: %s", err.Error()))
} else {
model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
task.Quota = actualQuota // 更新任务记录的实际扣费额度
// 记录消费日志
logContent := fmt.Sprintf("视频任务成功补扣费,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,补扣费 %s",
modelRatio, finalGroupRatio, taskResult.TotalTokens,
logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(quotaDelta))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
} else if quotaDelta < 0 {
// 需要退还多扣的费用
refundQuota := -quotaDelta
logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后返还:%s(实际消耗:%s,预扣费:%stokens%d",
task.TaskID,
logger.LogQuota(refundQuota),
logger.LogQuota(actualQuota),
logger.LogQuota(preConsumedQuota),
taskResult.TotalTokens,
))
if err := model.IncreaseUserQuota(task.UserId, refundQuota, false); err != nil {
logger.LogError(ctx, fmt.Sprintf("退还预扣费失败: %s", err.Error()))
} else {
task.Quota = actualQuota // 更新任务记录的实际扣费额度
// 记录退款日志
logContent := fmt.Sprintf("视频任务成功退还多扣费用,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,退还 %s",
modelRatio, finalGroupRatio, taskResult.TotalTokens,
logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(refundQuota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
} else {
// quotaDelta == 0, 预扣费刚好准确
logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%stokens%d",
task.TaskID, logger.LogQuota(actualQuota), taskResult.TotalTokens))
}
}
}
}
}
}
case model.TaskStatusFailure:
logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
task.Status = model.TaskStatusFailure
task.Progress = "100%"
if task.FinishTime == 0 {
task.FinishTime = now
}
task.FailReason = taskResult.Reason
logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
taskResult.Progress = "100%"
if quota != 0 {
if preStatus != model.TaskStatusFailure {
shouldRefund = true
} else {
logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID))
}
}
default:
return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
}
if taskResult.Progress != "" {
task.Progress = taskResult.Progress
}
if err := task.Update(); err != nil {
common.SysLog("UpdateVideoTask task error: " + err.Error())
shouldRefund = false
}
if shouldRefund {
// 任务失败且之前状态不是失败才退还额度,防止重复退还
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
logger.LogWarn(ctx, "Failed to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
return nil
}
func redactVideoResponseBody(body []byte) []byte {
var m map[string]any
if err := json.Unmarshal(body, &m); err != nil {
return body
}
resp, _ := m["response"].(map[string]any)
if resp != nil {
delete(resp, "bytesBase64Encoded")
if v, ok := resp["video"].(string); ok {
resp["video"] = truncateBase64(v)
}
if vs, ok := resp["videos"].([]any); ok {
for i := range vs {
if vm, ok := vs[i].(map[string]any); ok {
delete(vm, "bytesBase64Encoded")
}
}
}
}
b, err := json.Marshal(m)
if err != nil {
return body
}
return b
}
func truncateBase64(s string) string {
const maxKeep = 256
if len(s) <= maxKeep {
return s
}
return s[:maxKeep] + "..."
}
+271 -5
View File
@@ -2,10 +2,12 @@ package controller
import (
"bytes"
"database/sql"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"strconv"
"strings"
"testing"
@@ -14,6 +16,8 @@ import (
"github.com/QuantumNous/new-api/model"
"github.com/gin-gonic/gin"
"github.com/glebarez/sqlite"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
@@ -38,7 +42,36 @@ type tokenKeyResponse struct {
Key string `json:"key"`
}
func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
type sqliteColumnInfo struct {
Name string `gorm:"column:name"`
Type string `gorm:"column:type"`
}
type legacyToken struct {
Id int `gorm:"primaryKey"`
UserId int `gorm:"index"`
Key string `gorm:"column:key;type:char(48);uniqueIndex"`
Status int `gorm:"default:1"`
Name string `gorm:"index"`
CreatedTime int64 `gorm:"bigint"`
AccessedTime int64 `gorm:"bigint"`
ExpiredTime int64 `gorm:"bigint;default:-1"`
RemainQuota int `gorm:"default:0"`
UnlimitedQuota bool
ModelLimitsEnabled bool
ModelLimits string `gorm:"type:text"`
AllowIps *string `gorm:"default:''"`
UsedQuota int `gorm:"default:0"`
Group string `gorm:"column:group;default:''"`
CrossGroupRetry bool
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func (legacyToken) TableName() string {
return "tokens"
}
func openTokenControllerTestDB(t *testing.T) *gorm.DB {
t.Helper()
gin.SetMode(gin.TestMode)
@@ -55,10 +88,6 @@ func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
model.DB = db
model.LOG_DB = db
if err := db.AutoMigrate(&model.Token{}); err != nil {
t.Fatalf("failed to migrate token table: %v", err)
}
t.Cleanup(func() {
sqlDB, err := db.DB()
if err == nil {
@@ -69,6 +98,69 @@ func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
return db
}
func migrateTokenControllerTestDB(t *testing.T, db *gorm.DB) {
t.Helper()
if err := db.AutoMigrate(&model.Token{}); err != nil {
t.Fatalf("failed to migrate token table: %v", err)
}
}
func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
t.Helper()
db := openTokenControllerTestDB(t)
migrateTokenControllerTestDB(t, db)
return db
}
func openTokenControllerExternalDB(t *testing.T, dialect string, dsn string) (*gorm.DB, *bool) {
t.Helper()
gin.SetMode(gin.TestMode)
common.RedisEnabled = false
common.UsingSQLite = false
common.UsingMySQL = dialect == "mysql"
common.UsingPostgreSQL = dialect == "postgres"
var (
db *gorm.DB
err error
)
switch dialect {
case "mysql":
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
case "postgres":
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
default:
t.Fatalf("unsupported dialect %q", dialect)
}
if err != nil {
t.Fatalf("failed to open %s db: %v", dialect, err)
}
model.DB = db
model.LOG_DB = db
if db.Migrator().HasTable("tokens") {
t.Skipf("refusing to run %s migration compatibility test against external database because tokens table already exists", dialect)
}
managedTokensTable := new(bool)
t.Cleanup(func() {
if *managedTokensTable && db.Migrator().HasTable("tokens") {
_ = db.Migrator().DropTable("tokens")
}
sqlDB, err := db.DB()
if err == nil {
_ = sqlDB.Close()
}
})
return db, managedTokensTable
}
func seedToken(t *testing.T, db *gorm.DB, userID int, name string, rawKey string) *model.Token {
t.Helper()
@@ -124,6 +216,180 @@ func decodeAPIResponse(t *testing.T, recorder *httptest.ResponseRecorder) tokenA
return response
}
func getSQLiteColumnType(t *testing.T, db *gorm.DB, tableName string, columnName string) string {
t.Helper()
var columns []sqliteColumnInfo
if err := db.Raw("PRAGMA table_info(" + tableName + ")").Scan(&columns).Error; err != nil {
t.Fatalf("failed to inspect %s schema: %v", tableName, err)
}
for _, column := range columns {
if column.Name == columnName {
return strings.ToLower(column.Type)
}
}
t.Fatalf("column %s not found in %s schema", columnName, tableName)
return ""
}
func getTokenKeyColumnType(t *testing.T, db *gorm.DB, dialect string) string {
t.Helper()
switch dialect {
case "sqlite":
return getSQLiteColumnType(t, db, "tokens", "key")
case "mysql":
var columnType string
if err := db.Raw(`SELECT COLUMN_TYPE FROM information_schema.columns
WHERE table_schema = DATABASE() AND table_name = ? AND column_name = ?`,
"tokens", "key").Scan(&columnType).Error; err != nil {
t.Fatalf("failed to inspect mysql token key column: %v", err)
}
return strings.ToLower(columnType)
case "postgres":
var dataType string
var maxLength sql.NullInt64
if err := db.Raw(`SELECT data_type, character_maximum_length
FROM information_schema.columns
WHERE table_schema = current_schema() AND table_name = ? AND column_name = ?`,
"tokens", "key").Row().Scan(&dataType, &maxLength); err != nil {
t.Fatalf("failed to inspect postgres token key column: %v", err)
}
switch strings.ToLower(dataType) {
case "character varying":
return fmt.Sprintf("varchar(%d)", maxLength.Int64)
case "character":
return fmt.Sprintf("char(%d)", maxLength.Int64)
default:
if maxLength.Valid {
return fmt.Sprintf("%s(%d)", strings.ToLower(dataType), maxLength.Int64)
}
return strings.ToLower(dataType)
}
default:
t.Fatalf("unsupported dialect %q", dialect)
return ""
}
}
func runTokenMigrationCompatibilityTest(t *testing.T, db *gorm.DB, dialect string, managedTokensTable *bool) {
t.Helper()
legacyKey := strings.Repeat("a", 48)
longKey := strings.Repeat("b", 64)
if err := db.AutoMigrate(&legacyToken{}); err != nil {
t.Fatalf("failed to create legacy token schema: %v", err)
}
if managedTokensTable != nil {
*managedTokensTable = true
}
if err := db.Create(&legacyToken{
UserId: 7,
Key: legacyKey,
Status: common.TokenStatusEnabled,
Name: "legacy-token",
CreatedTime: 1,
AccessedTime: 1,
ExpiredTime: -1,
RemainQuota: 100,
UnlimitedQuota: true,
ModelLimitsEnabled: false,
ModelLimits: "",
AllowIps: common.GetPointer(""),
UsedQuota: 0,
Group: "default",
CrossGroupRetry: false,
}).Error; err != nil {
t.Fatalf("failed to seed legacy token row: %v", err)
}
if got := getTokenKeyColumnType(t, db, dialect); got != "char(48)" {
t.Fatalf("expected legacy key column type char(48), got %q", got)
}
migrateTokenControllerTestDB(t, db)
if got := getTokenKeyColumnType(t, db, dialect); got != "varchar(128)" {
t.Fatalf("expected migrated key column type varchar(128), got %q", got)
}
var migratedToken model.Token
if err := db.First(&migratedToken, "name = ?", "legacy-token").Error; err != nil {
t.Fatalf("failed to load migrated token row: %v", err)
}
if migratedToken.Key != legacyKey {
t.Fatalf("expected migrated token key %q, got %q", legacyKey, migratedToken.Key)
}
if migratedToken.Name != "legacy-token" {
t.Fatalf("expected migrated token name to be preserved, got %q", migratedToken.Name)
}
inserted := model.Token{
UserId: 8,
Name: "long-token",
Key: longKey,
Status: common.TokenStatusEnabled,
CreatedTime: 1,
AccessedTime: 1,
ExpiredTime: -1,
RemainQuota: 200,
UnlimitedQuota: true,
ModelLimitsEnabled: false,
ModelLimits: "",
AllowIps: common.GetPointer(""),
UsedQuota: 0,
Group: "default",
CrossGroupRetry: false,
}
if err := db.Create(&inserted).Error; err != nil {
t.Fatalf("failed to insert long token after migration: %v", err)
}
var fetched model.Token
if err := db.First(&fetched, "id = ?", inserted.Id).Error; err != nil {
t.Fatalf("failed to fetch long token after migration: %v", err)
}
if fetched.Key != longKey {
t.Fatalf("expected long token key %q, got %q", longKey, fetched.Key)
}
}
func TestTokenAutoMigrateUsesVarchar128KeyColumn(t *testing.T) {
db := setupTokenControllerTestDB(t)
if got := getTokenKeyColumnType(t, db, "sqlite"); got != "varchar(128)" {
t.Fatalf("expected key column type varchar(128), got %q", got)
}
}
func TestTokenMigrationFromChar48ToVarchar128(t *testing.T) {
db := openTokenControllerTestDB(t)
runTokenMigrationCompatibilityTest(t, db, "sqlite", nil)
}
func TestTokenMigrationFromChar48ToVarchar128MySQL(t *testing.T) {
dsn := os.Getenv("TEST_MYSQL_DSN")
if dsn == "" {
t.Skip("set TEST_MYSQL_DSN to run mysql migration compatibility test")
}
db, managedTokensTable := openTokenControllerExternalDB(t, "mysql", dsn)
runTokenMigrationCompatibilityTest(t, db, "mysql", managedTokensTable)
}
func TestTokenMigrationFromChar48ToVarchar128Postgres(t *testing.T) {
dsn := os.Getenv("TEST_POSTGRES_DSN")
if dsn == "" {
t.Skip("set TEST_POSTGRES_DSN to run postgres migration compatibility test")
}
db, managedTokensTable := openTokenControllerExternalDB(t, "postgres", dsn)
runTokenMigrationCompatibilityTest(t, db, "postgres", managedTokensTable)
}
func TestGetAllTokensMasksKeyInResponse(t *testing.T) {
db := setupTokenControllerTestDB(t)
token := seedToken(t, db, 1, "list-token", "abcd1234efgh5678")
+97 -65
View File
@@ -2,7 +2,7 @@ package controller
import (
"fmt"
"log"
"net/http"
"net/url"
"strconv"
"sync"
@@ -27,7 +27,7 @@ func GetTopUpInfo(c *gin.Context) {
payMethods := operation_setting.PayMethods
// 如果启用了 Stripe 支付,添加到支付方法列表
if setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "" {
if isStripeTopUpEnabled() {
// 检查是否已经包含 Stripe
hasStripe := false
for _, method := range payMethods {
@@ -49,19 +49,11 @@ func GetTopUpInfo(c *gin.Context) {
}
// 如果启用了 Waffo 支付,添加到支付方法列表
enableWaffo := setting.WaffoEnabled &&
((!setting.WaffoSandbox &&
setting.WaffoApiKey != "" &&
setting.WaffoPrivateKey != "" &&
setting.WaffoPublicCert != "") ||
(setting.WaffoSandbox &&
setting.WaffoSandboxApiKey != "" &&
setting.WaffoSandboxPrivateKey != "" &&
setting.WaffoSandboxPublicCert != ""))
enableWaffo := isWaffoTopUpEnabled()
if enableWaffo {
hasWaffo := false
for _, method := range payMethods {
if method["type"] == "waffo" {
if method["type"] == model.PaymentMethodWaffo {
hasWaffo = true
break
}
@@ -70,7 +62,7 @@ func GetTopUpInfo(c *gin.Context) {
if !hasWaffo {
waffoMethod := map[string]string{
"name": "Waffo (Global Payment)",
"type": "waffo",
"type": model.PaymentMethodWaffo,
"color": "rgba(var(--semi-blue-5), 1)",
"min_topup": strconv.Itoa(setting.WaffoMinTopUp),
}
@@ -78,24 +70,46 @@ func GetTopUpInfo(c *gin.Context) {
}
}
enableWaffoPancake := isWaffoPancakeTopUpEnabled()
if enableWaffoPancake {
hasWaffoPancake := false
for _, method := range payMethods {
if method["type"] == model.PaymentMethodWaffoPancake {
hasWaffoPancake = true
break
}
}
if !hasWaffoPancake {
payMethods = append(payMethods, map[string]string{
"name": "Waffo Pancake",
"type": model.PaymentMethodWaffoPancake,
"color": "rgba(var(--semi-orange-5), 1)",
"min_topup": strconv.Itoa(setting.WaffoPancakeMinTopUp),
})
}
}
data := gin.H{
"enable_online_topup": operation_setting.PayAddress != "" && operation_setting.EpayId != "" && operation_setting.EpayKey != "",
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
"enable_creem_topup": setting.CreemApiKey != "" && setting.CreemProducts != "[]",
"enable_waffo_topup": enableWaffo,
"enable_online_topup": isEpayTopUpEnabled(),
"enable_stripe_topup": isStripeTopUpEnabled(),
"enable_creem_topup": isCreemTopUpEnabled(),
"enable_waffo_topup": enableWaffo,
"enable_waffo_pancake_topup": enableWaffoPancake,
"waffo_pay_methods": func() interface{} {
if enableWaffo {
return setting.GetWaffoPayMethods()
}
return nil
}(),
"creem_products": setting.CreemProducts,
"pay_methods": payMethods,
"min_topup": operation_setting.MinTopUp,
"stripe_min_topup": setting.StripeMinTopUp,
"waffo_min_topup": setting.WaffoMinTopUp,
"amount_options": operation_setting.GetPaymentSetting().AmountOptions,
"discount": operation_setting.GetPaymentSetting().AmountDiscount,
"creem_products": setting.CreemProducts,
"pay_methods": payMethods,
"min_topup": operation_setting.MinTopUp,
"stripe_min_topup": setting.StripeMinTopUp,
"waffo_min_topup": setting.WaffoMinTopUp,
"waffo_pancake_min_topup": setting.WaffoPancakeMinTopUp,
"amount_options": operation_setting.GetPaymentSetting().AmountOptions,
"discount": operation_setting.GetPaymentSetting().AmountDiscount,
}
common.ApiSuccess(c, data)
}
@@ -167,28 +181,28 @@ func RequestEpay(c *gin.Context) {
var req EpayRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
if req.Amount < getMinTopup() {
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
return
}
id := c.GetInt("id")
group, err := model.GetUserGroup(id, true)
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "获取用户分组失败"})
return
}
payMoney := getPayMoney(req.Amount, group)
if payMoney < 0.01 {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "充值金额过低"})
return
}
if !operation_setting.ContainsPayMethod(req.PaymentMethod) {
c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "支付方式不存在"})
return
}
@@ -199,7 +213,7 @@ func RequestEpay(c *gin.Context) {
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
client := GetEpayClient()
if client == nil {
c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "当前管理员未配置支付信息"})
return
}
uri, params, err := client.Purchase(&epay.PurchaseArgs{
@@ -212,7 +226,8 @@ func RequestEpay(c *gin.Context) {
ReturnUrl: returnUrl,
})
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 拉起支付失败 user_id=%d trade_no=%s payment_method=%s amount=%d error=%q", id, tradeNo, req.PaymentMethod, req.Amount, err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
return
}
amount := req.Amount
@@ -222,20 +237,23 @@ func RequestEpay(c *gin.Context) {
amount = dAmount.Div(dQuotaPerUnit).IntPart()
}
topUp := &model.TopUp{
UserId: id,
Amount: amount,
Money: payMoney,
TradeNo: tradeNo,
PaymentMethod: req.PaymentMethod,
CreateTime: time.Now().Unix(),
Status: "pending",
UserId: id,
Amount: amount,
Money: payMoney,
TradeNo: tradeNo,
PaymentMethod: req.PaymentMethod,
PaymentProvider: model.PaymentProviderEpay,
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
}
err = topUp.Insert()
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 创建充值订单失败 user_id=%d trade_no=%s payment_method=%s amount=%d error=%q", id, tradeNo, req.PaymentMethod, req.Amount, err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
return
}
c.JSON(200, gin.H{"message": "success", "data": params, "url": uri})
logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 充值订单创建成功 user_id=%d trade_no=%s payment_method=%s amount=%d money=%.2f uri=%q params=%q", id, tradeNo, req.PaymentMethod, req.Amount, payMoney, uri, common.GetJsonString(params)))
c.JSON(http.StatusOK, gin.H{"message": "success", "data": params, "url": uri})
}
// tradeNo lock
@@ -281,12 +299,18 @@ func UnlockOrder(tradeNo string) {
}
func EpayNotify(c *gin.Context) {
if !isEpayWebhookEnabled() {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 webhook 被拒绝 reason=webhook_disabled path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
_, _ = c.Writer.Write([]byte("fail"))
return
}
var params map[string]string
if c.Request.Method == "POST" {
// POST 请求:从 POST body 解析参数
if err := c.Request.ParseForm(); err != nil {
log.Println("易支付回调POST解析失败:", err)
logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 webhook POST 表单解析失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
_, _ = c.Writer.Write([]byte("fail"))
return
}
@@ -301,54 +325,63 @@ func EpayNotify(c *gin.Context) {
return r
}, map[string]string{})
}
logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 webhook 收到请求 path=%q client_ip=%s method=%s params=%q", c.Request.RequestURI, c.ClientIP(), c.Request.Method, common.GetJsonString(params)))
if len(params) == 0 {
log.Println("易支付回调参数为空")
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 webhook 参数为空 path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
_, _ = c.Writer.Write([]byte("fail"))
return
}
client := GetEpayClient()
if client == nil {
log.Println("易支付回调失败 未找到配置信息")
logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 client 未初始化 path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
_, err := c.Writer.Write([]byte("fail"))
if err != nil {
log.Println("易支付回调写入失败")
logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 webhook 响应写入失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
}
return
}
verifyInfo, err := client.Verify(params)
if err == nil && verifyInfo.VerifyStatus {
logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 webhook 验签成功 trade_no=%s callback_type=%s trade_status=%s client_ip=%s verify_info=%q", verifyInfo.ServiceTradeNo, verifyInfo.Type, verifyInfo.TradeStatus, c.ClientIP(), common.GetJsonString(verifyInfo)))
_, err := c.Writer.Write([]byte("success"))
if err != nil {
log.Println("易支付回调写入失败")
logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 webhook 响应写入失败 trade_no=%s client_ip=%s error=%q", verifyInfo.ServiceTradeNo, c.ClientIP(), err.Error()))
}
} else {
_, err := c.Writer.Write([]byte("fail"))
if err != nil {
log.Println("易支付回调写入失败")
logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 webhook 响应写入失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
}
if err != nil {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 webhook 验签失败 path=%q client_ip=%s verify_error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
} else {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 webhook 验签失败 path=%q client_ip=%s verify_status=false", c.Request.RequestURI, c.ClientIP()))
}
log.Println("易支付回调签名验证失败")
return
}
if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
log.Println(verifyInfo)
LockOrder(verifyInfo.ServiceTradeNo)
defer UnlockOrder(verifyInfo.ServiceTradeNo)
topUp := model.GetTopUpByTradeNo(verifyInfo.ServiceTradeNo)
if topUp == nil {
log.Printf("易支付回调未找到订单: %v", verifyInfo)
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 回调订单不存在 trade_no=%s callback_type=%s client_ip=%s verify_info=%q", verifyInfo.ServiceTradeNo, verifyInfo.Type, c.ClientIP(), common.GetJsonString(verifyInfo)))
return
}
if topUp.PaymentMethod == "stripe" || topUp.PaymentMethod == "creem" || topUp.PaymentMethod == "waffo" {
log.Printf("易支付回调订单支付方式不匹配: %s, 订单号: %s", topUp.PaymentMethod, verifyInfo.ServiceTradeNo)
if topUp.PaymentProvider != model.PaymentProviderEpay {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 订单支付网关不匹配 trade_no=%s order_provider=%s callback_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentProvider, verifyInfo.Type, c.ClientIP()))
return
}
if topUp.Status == "pending" {
topUp.Status = "success"
if topUp.Status == common.TopUpStatusPending {
if topUp.PaymentMethod != verifyInfo.Type {
logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 实际支付方式与订单不同 trade_no=%s order_payment_method=%s actual_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentMethod, verifyInfo.Type, c.ClientIP()))
topUp.PaymentMethod = verifyInfo.Type
}
topUp.Status = common.TopUpStatusSuccess
err := topUp.Update()
if err != nil {
log.Printf("易支付回调更新订单失败: %v", topUp)
logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 更新充值订单失败 trade_no=%s user_id=%d client_ip=%s error=%q topup=%q", topUp.TradeNo, topUp.UserId, c.ClientIP(), err.Error(), common.GetJsonString(topUp)))
return
}
//user, _ := model.GetUserById(topUp.UserId, false)
@@ -358,14 +391,14 @@ func EpayNotify(c *gin.Context) {
quotaToAdd := int(dAmount.Mul(dQuotaPerUnit).IntPart())
err = model.IncreaseUserQuota(topUp.UserId, quotaToAdd, true)
if err != nil {
log.Printf("易支付回调更新用户失败: %v", topUp)
logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 更新用户额度失败 trade_no=%s user_id=%d client_ip=%s quota_to_add=%d error=%q topup=%q", topUp.TradeNo, topUp.UserId, c.ClientIP(), quotaToAdd, err.Error(), common.GetJsonString(topUp)))
return
}
log.Printf("易支付回调更新用户成功 %v", topUp)
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", logger.LogQuota(quotaToAdd), topUp.Money))
logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 充值成功 trade_no=%s user_id=%d client_ip=%s quota_to_add=%d money=%.2f topup=%q", topUp.TradeNo, topUp.UserId, c.ClientIP(), quotaToAdd, topUp.Money, common.GetJsonString(topUp)))
model.RecordTopupLog(topUp.UserId, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", logger.LogQuota(quotaToAdd), topUp.Money), c.ClientIP(), topUp.PaymentMethod, "epay")
}
} else {
log.Printf("易支付异常回调: %v", verifyInfo)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 webhook 忽略事件 trade_no=%s callback_type=%s trade_status=%s client_ip=%s verify_info=%q", verifyInfo.ServiceTradeNo, verifyInfo.Type, verifyInfo.TradeStatus, c.ClientIP(), common.GetJsonString(verifyInfo)))
}
}
@@ -373,26 +406,26 @@ func RequestAmount(c *gin.Context) {
var req AmountRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
if req.Amount < getMinTopup() {
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
return
}
id := c.GetInt("id")
group, err := model.GetUserGroup(id, true)
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "获取用户分组失败"})
return
}
payMoney := getPayMoney(req.Amount, group)
if payMoney <= 0.01 {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "充值金额过低"})
return
}
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
c.JSON(http.StatusOK, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
}
func GetUserTopUps(c *gin.Context) {
@@ -461,10 +494,9 @@ func AdminCompleteTopUp(c *gin.Context) {
LockOrder(req.TradeNo)
defer UnlockOrder(req.TradeNo)
if err := model.ManualCompleteTopUp(req.TradeNo); err != nil {
if err := model.ManualCompleteTopUp(req.TradeNo, c.ClientIP()); err != nil {
common.ApiError(c, err)
return
}
common.ApiSuccess(c, nil)
}
+64 -72
View File
@@ -2,6 +2,7 @@ package controller
import (
"bytes"
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
@@ -9,10 +10,10 @@ import (
"errors"
"fmt"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting"
"io"
"log"
"net/http"
"time"
@@ -20,10 +21,7 @@ import (
"github.com/thanhpk/randstr"
)
const (
PaymentMethodCreem = "creem"
CreemSignatureHeader = "creem-signature"
)
const CreemSignatureHeader = "creem-signature"
var creemAdaptor = &CreemAdaptor{}
@@ -37,9 +35,9 @@ func generateCreemSignature(payload string, secret string) string {
// 验证Creem webhook签名
func verifyCreemSignature(payload string, signature string, secret string) bool {
if secret == "" {
log.Printf("Creem webhook secret not set")
logger.LogWarn(context.Background(), fmt.Sprintf("Creem webhook secret 未配置 test_mode=%t signature=%q body=%q", setting.CreemTestMode, signature, payload))
if setting.CreemTestMode {
log.Printf("Skip Creem webhook sign verify in test mode")
logger.LogInfo(context.Background(), fmt.Sprintf("Creem webhook 验签已跳过 reason=test_mode signature=%q body=%q", signature, payload))
return true
}
return false
@@ -66,13 +64,13 @@ type CreemAdaptor struct {
}
func (*CreemAdaptor) RequestPay(c *gin.Context, req *CreemPayRequest) {
if req.PaymentMethod != PaymentMethodCreem {
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"})
if req.PaymentMethod != model.PaymentMethodCreem {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "不支持的支付渠道"})
return
}
if req.ProductId == "" {
c.JSON(200, gin.H{"message": "error", "data": "请选择产品"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "请选择产品"})
return
}
@@ -80,8 +78,8 @@ func (*CreemAdaptor) RequestPay(c *gin.Context, req *CreemPayRequest) {
var products []CreemProduct
err := json.Unmarshal([]byte(setting.CreemProducts), &products)
if err != nil {
log.Println("解析Creem产品列表失败", err)
c.JSON(200, gin.H{"message": "error", "data": "产品配置错误"})
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 产品配置解析失败 user_id=%d error=%q", c.GetInt("id"), err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "产品配置错误"})
return
}
@@ -95,7 +93,7 @@ func (*CreemAdaptor) RequestPay(c *gin.Context, req *CreemPayRequest) {
}
if selectedProduct == nil {
c.JSON(200, gin.H{"message": "error", "data": "产品不存在"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "产品不存在"})
return
}
@@ -108,33 +106,33 @@ func (*CreemAdaptor) RequestPay(c *gin.Context, req *CreemPayRequest) {
// 先创建订单记录,使用产品配置的金额和充值额度
topUp := &model.TopUp{
UserId: id,
Amount: selectedProduct.Quota, // 充值额度
Money: selectedProduct.Price, // 支付金额
TradeNo: referenceId,
PaymentMethod: PaymentMethodCreem,
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
UserId: id,
Amount: selectedProduct.Quota, // 充值额度
Money: selectedProduct.Price, // 支付金额
TradeNo: referenceId,
PaymentMethod: model.PaymentMethodCreem,
PaymentProvider: model.PaymentProviderCreem,
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
}
err = topUp.Insert()
if err != nil {
log.Printf("创建Creem订单失败: %v", err)
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 创建充值订单失败 user_id=%d trade_no=%s product_id=%s error=%q", id, referenceId, selectedProduct.ProductId, err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
return
}
// 创建支付链接,传入用户邮箱
checkoutUrl, err := genCreemLink(referenceId, selectedProduct, user.Email, user.Username)
checkoutUrl, err := genCreemLink(c.Request.Context(), referenceId, selectedProduct, user.Email, user.Username)
if err != nil {
log.Printf("获取Creem支付链接失败: %v", err)
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 创建支付链接失败 user_id=%d trade_no=%s product_id=%s error=%q", id, referenceId, selectedProduct.ProductId, err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
return
}
log.Printf("Creem订单创建成功 - 用户ID: %d, 订单号: %s, 产品: %s, 充值额度: %d, 支付金额: %.2f",
id, referenceId, selectedProduct.Name, selectedProduct.Quota, selectedProduct.Price)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 充值订单创建成功 user_id=%d trade_no=%s product_id=%s product_name=%q quota=%d money=%.2f", id, referenceId, selectedProduct.ProductId, selectedProduct.Name, selectedProduct.Quota, selectedProduct.Price))
c.JSON(200, gin.H{
c.JSON(http.StatusOK, gin.H{
"message": "success",
"data": gin.H{
"checkout_url": checkoutUrl,
@@ -149,20 +147,19 @@ func RequestCreemPay(c *gin.Context) {
// 读取body内容用于打印,同时保留原始数据供后续使用
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
log.Printf("read creem pay req body err: %v", err)
c.JSON(200, gin.H{"message": "error", "data": "read query error"})
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 支付请求读取失败 error=%q", err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "read query error"})
return
}
// 打印body内容
log.Printf("creem pay request body: %s", string(bodyBytes))
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 支付请求已收到 user_id=%d body=%q", c.GetInt("id"), string(bodyBytes)))
// 重新设置body供后续的ShouldBindJSON使用
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
err = c.ShouldBindJSON(&req)
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
creemAdaptor.RequestPay(c, &req)
@@ -230,35 +227,37 @@ type CreemWebhookEvent struct {
}
func CreemWebhook(c *gin.Context) {
if !isCreemWebhookEnabled() {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Creem webhook 被拒绝 reason=webhook_disabled path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
c.AbortWithStatus(http.StatusForbidden)
return
}
// 读取body内容用于打印,同时保留原始数据供后续使用
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
log.Printf("读取Creem Webhook请求body失败: %v", err)
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem webhook 读取请求体失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
c.AbortWithStatus(http.StatusBadRequest)
return
}
// 获取签名头
signature := c.GetHeader(CreemSignatureHeader)
// 打印关键信息(避免输出完整敏感payload)
log.Printf("Creem Webhook - URI: %s", c.Request.RequestURI)
if setting.CreemTestMode {
log.Printf("Creem Webhook - Signature: %s , Body: %s", signature, bodyBytes)
} else if signature == "" {
log.Printf("Creem Webhook缺少签名头")
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem webhook 收到请求 path=%q client_ip=%s signature=%q body=%q", c.Request.RequestURI, c.ClientIP(), signature, string(bodyBytes)))
if signature == "" {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Creem webhook 缺少签名 path=%q client_ip=%s body=%q", c.Request.RequestURI, c.ClientIP(), string(bodyBytes)))
c.AbortWithStatus(http.StatusUnauthorized)
return
}
// 验证签名
if !verifyCreemSignature(string(bodyBytes), signature, setting.CreemWebhookSecret) {
log.Printf("Creem Webhook签名验证失败")
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Creem webhook 验签失败 path=%q client_ip=%s signature=%q body=%q", c.Request.RequestURI, c.ClientIP(), signature, string(bodyBytes)))
c.AbortWithStatus(http.StatusUnauthorized)
return
}
log.Printf("Creem Webhook签名验证成功")
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem webhook 验签成功 path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
// 重新设置body供后续的ShouldBindJSON使用
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
@@ -266,19 +265,19 @@ func CreemWebhook(c *gin.Context) {
// 解析新格式的webhook数据
var webhookEvent CreemWebhookEvent
if err := c.ShouldBindJSON(&webhookEvent); err != nil {
log.Printf("解析Creem Webhook参数失败: %v", err)
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem webhook 解析失败 path=%q client_ip=%s error=%q body=%q", c.Request.RequestURI, c.ClientIP(), err.Error(), string(bodyBytes)))
c.AbortWithStatus(http.StatusBadRequest)
return
}
log.Printf("Creem Webhook解析成功 - EventType: %s, EventId: %s", webhookEvent.EventType, webhookEvent.Id)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem webhook 解析成功 event_type=%s event_id=%s request_id=%s order_id=%s order_status=%s", webhookEvent.EventType, webhookEvent.Id, webhookEvent.Object.RequestId, webhookEvent.Object.Order.Id, webhookEvent.Object.Order.Status))
// 根据事件类型处理不同的webhook
switch webhookEvent.EventType {
case "checkout.completed":
handleCheckoutCompleted(c, &webhookEvent)
default:
log.Printf("忽略Creem Webhook事件类型: %s", webhookEvent.EventType)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem webhook 忽略事件 event_type=%s event_id=%s", webhookEvent.EventType, webhookEvent.Id))
c.Status(http.StatusOK)
}
}
@@ -287,7 +286,7 @@ func CreemWebhook(c *gin.Context) {
func handleCheckoutCompleted(c *gin.Context, event *CreemWebhookEvent) {
// 验证订单状态
if event.Object.Order.Status != "paid" {
log.Printf("订单状态不是已支付: %s, 跳过处理", event.Object.Order.Status)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 订单状态未支付,忽略处理 request_id=%s order_id=%s order_status=%s", event.Object.RequestId, event.Object.Order.Id, event.Object.Order.Status))
c.Status(http.StatusOK)
return
}
@@ -295,7 +294,7 @@ func handleCheckoutCompleted(c *gin.Context, event *CreemWebhookEvent) {
// 获取引用ID(这是我们创建订单时传递的request_id)
referenceId := event.Object.RequestId
if referenceId == "" {
log.Println("Creem Webhook缺少request_id字段")
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Creem webhook 缺少 request_id event_id=%s order_id=%s", event.Id, event.Object.Order.Id))
c.AbortWithStatus(http.StatusBadRequest)
return
}
@@ -303,40 +302,35 @@ func handleCheckoutCompleted(c *gin.Context, event *CreemWebhookEvent) {
// Try complete subscription order first
LockOrder(referenceId)
defer UnlockOrder(referenceId)
if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(event)); err == nil {
if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(event), model.PaymentProviderCreem, ""); err == nil {
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 订阅订单处理成功 trade_no=%s creem_order_id=%s", referenceId, event.Object.Order.Id))
c.Status(http.StatusOK)
return
} else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) {
log.Printf("Creem订阅订单处理失败: %s, 订单号: %s", err.Error(), referenceId)
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 订阅订单处理失败 trade_no=%s creem_order_id=%s error=%q", referenceId, event.Object.Order.Id, err.Error()))
c.AbortWithStatus(http.StatusInternalServerError)
return
}
// 验证订单类型,目前只处理一次性付款(充值)
if event.Object.Order.Type != "onetime" {
log.Printf("暂不支持订单类型: %s, 跳过处理", event.Object.Order.Type)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 暂不支持订单类型,忽略处理 request_id=%s creem_order_id=%s order_type=%s", referenceId, event.Object.Order.Id, event.Object.Order.Type))
c.Status(http.StatusOK)
return
}
// 记录详细的支付信息
log.Printf("处理Creem支付完成 - 订单号: %s, Creem订单ID: %s, 支付金额: %d %s, 客户邮箱: <redacted>, 产品: %s",
referenceId,
event.Object.Order.Id,
event.Object.Order.AmountPaid,
event.Object.Order.Currency,
event.Object.Product.Name)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 支付完成回调 trade_no=%s creem_order_id=%s amount_paid=%d currency=%s product_name=%q customer_email=%q customer_name=%q", referenceId, event.Object.Order.Id, event.Object.Order.AmountPaid, event.Object.Order.Currency, event.Object.Product.Name, event.Object.Customer.Email, event.Object.Customer.Name))
// 查询本地订单确认存在
topUp := model.GetTopUpByTradeNo(referenceId)
if topUp == nil {
log.Printf("Creem充值订单不存在: %s", referenceId)
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Creem 充值订单不存在 trade_no=%s creem_order_id=%s", referenceId, event.Object.Order.Id))
c.AbortWithStatus(http.StatusBadRequest)
return
}
if topUp.Status != common.TopUpStatusPending {
log.Printf("Creem充值订单状态错误: %s, 当前状态: %s", referenceId, topUp.Status)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 充值订单状态非 pending,忽略处理 trade_no=%s status=%s creem_order_id=%s", referenceId, topUp.Status, event.Object.Order.Id))
c.Status(http.StatusOK) // 已处理过的订单,返回成功避免重复处理
return
}
@@ -347,21 +341,20 @@ func handleCheckoutCompleted(c *gin.Context, event *CreemWebhookEvent) {
// 防护性检查,确保邮箱和姓名不为空字符串
if customerEmail == "" {
log.Printf("警告:Creem回调客户邮箱为空 - 订单号: %s", referenceId)
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Creem 回调客户邮箱为空 trade_no=%s creem_order_id=%s", referenceId, event.Object.Order.Id))
}
if customerName == "" {
log.Printf("警告:Creem回调客户姓名为空 - 订单号: %s", referenceId)
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Creem 回调客户姓名为空 trade_no=%s creem_order_id=%s", referenceId, event.Object.Order.Id))
}
err := model.RechargeCreem(referenceId, customerEmail, customerName)
err := model.RechargeCreem(referenceId, customerEmail, customerName, c.ClientIP())
if err != nil {
log.Printf("Creem充值处理失败: %s, 订单号: %s", err.Error(), referenceId)
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 充值处理失败 trade_no=%s creem_order_id=%s client_ip=%s error=%q", referenceId, event.Object.Order.Id, c.ClientIP(), err.Error()))
c.AbortWithStatus(http.StatusInternalServerError)
return
}
log.Printf("Creem充值成功 - 订单号: %s, 充值额度: %d, 支付金额: %.2f",
referenceId, topUp.Amount, topUp.Money)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 充值成功 trade_no=%s creem_order_id=%s quota=%d money=%.2f client_ip=%s", referenceId, event.Object.Order.Id, topUp.Amount, topUp.Money, c.ClientIP()))
c.Status(http.StatusOK)
}
@@ -379,7 +372,7 @@ type CreemCheckoutResponse struct {
Id string `json:"id"`
}
func genCreemLink(referenceId string, product *CreemProduct, email string, username string) (string, error) {
func genCreemLink(ctx context.Context, referenceId string, product *CreemProduct, email string, username string) (string, error) {
if setting.CreemApiKey == "" {
return "", fmt.Errorf("未配置Creem API密钥")
}
@@ -388,7 +381,7 @@ func genCreemLink(referenceId string, product *CreemProduct, email string, usern
apiUrl := "https://api.creem.io/v1/checkouts"
if setting.CreemTestMode {
apiUrl = "https://test-api.creem.io/v1/checkouts"
log.Printf("使用Creem测试环境: %s", apiUrl)
logger.LogInfo(ctx, fmt.Sprintf("Creem 使用测试环境 api_url=%s", apiUrl))
}
// 构建请求数据,确保包含用户邮箱
@@ -424,8 +417,7 @@ func genCreemLink(referenceId string, product *CreemProduct, email string, usern
req.Header.Set("Content-Type", "application/json")
req.Header.Set("x-api-key", setting.CreemApiKey)
log.Printf("发送Creem支付请求 - URL: %s, 产品ID: %s, 用户邮箱: %s, 订单号: %s",
apiUrl, product.ProductId, email, referenceId)
logger.LogInfo(ctx, fmt.Sprintf("Creem 支付请求已发送 api_url=%s product_id=%s email=%q trade_no=%s", apiUrl, product.ProductId, email, referenceId))
// 发送请求
client := &http.Client{
@@ -443,7 +435,7 @@ func genCreemLink(referenceId string, product *CreemProduct, email string, usern
return "", fmt.Errorf("读取响应失败: %v", err)
}
log.Printf("Creem API resp - status code: %d, resp: %s", resp.StatusCode, string(body))
logger.LogInfo(ctx, fmt.Sprintf("Creem API 响应已收到 trade_no=%s status_code=%d body=%q", referenceId, resp.StatusCode, string(body)))
// 检查响应状态
if resp.StatusCode/100 != 2 {
@@ -460,6 +452,6 @@ func genCreemLink(referenceId string, product *CreemProduct, email string, usern
return "", fmt.Errorf("Creem API resp no checkout url ")
}
log.Printf("Creem 支付链接创建成功 - 订单号: %s, 支付链接: %s", referenceId, checkoutResp.CheckoutUrl)
logger.LogInfo(ctx, fmt.Sprintf("Creem 支付链接创建成功 trade_no=%s response_id=%s checkout_url=%q", referenceId, checkoutResp.Id, checkoutResp.CheckoutUrl))
return checkoutResp.CheckoutUrl, nil
}
+74 -75
View File
@@ -1,16 +1,17 @@
package controller
import (
"context"
"errors"
"fmt"
"io"
"log"
"net/http"
"strconv"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/operation_setting"
@@ -23,10 +24,6 @@ import (
"github.com/thanhpk/randstr"
)
const (
PaymentMethodStripe = "stripe"
)
var stripeAdaptor = &StripeAdaptor{}
// StripePayRequest represents a payment request for Stripe checkout.
@@ -48,34 +45,34 @@ type StripeAdaptor struct {
func (*StripeAdaptor) RequestAmount(c *gin.Context, req *StripePayRequest) {
if req.Amount < getStripeMinTopup() {
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup())})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup())})
return
}
id := c.GetInt("id")
group, err := model.GetUserGroup(id, true)
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "获取用户分组失败"})
return
}
payMoney := getStripePayMoney(float64(req.Amount), group)
if payMoney <= 0.01 {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "充值金额过低"})
return
}
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
c.JSON(http.StatusOK, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
}
func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) {
if req.PaymentMethod != PaymentMethodStripe {
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"})
if req.PaymentMethod != model.PaymentMethodStripe {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "不支持的支付渠道"})
return
}
if req.Amount < getStripeMinTopup() {
c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup()), "data": 10})
c.JSON(http.StatusOK, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup()), "data": 10})
return
}
if req.Amount > 10000 {
c.JSON(200, gin.H{"message": "充值数量不能大于 10000", "data": 10})
c.JSON(http.StatusOK, gin.H{"message": "充值数量不能大于 10000", "data": 10})
return
}
@@ -98,26 +95,29 @@ func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) {
payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, req.Amount, req.SuccessURL, req.CancelURL)
if err != nil {
log.Println("获取Stripe Checkout支付链接失败", err)
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
logger.LogError(c.Request.Context(), fmt.Sprintf("Stripe 创建 Checkout Session 失败 user_id=%d trade_no=%s amount=%d error=%q", id, referenceId, req.Amount, err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
return
}
topUp := &model.TopUp{
UserId: id,
Amount: req.Amount,
Money: chargedMoney,
TradeNo: referenceId,
PaymentMethod: PaymentMethodStripe,
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
UserId: id,
Amount: req.Amount,
Money: chargedMoney,
TradeNo: referenceId,
PaymentMethod: model.PaymentMethodStripe,
PaymentProvider: model.PaymentProviderStripe,
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
}
err = topUp.Insert()
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
logger.LogError(c.Request.Context(), fmt.Sprintf("Stripe 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, referenceId, req.Amount, err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
return
}
c.JSON(200, gin.H{
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Stripe 充值订单创建成功 user_id=%d trade_no=%s amount=%d money=%.2f", id, referenceId, req.Amount, chargedMoney))
c.JSON(http.StatusOK, gin.H{
"message": "success",
"data": gin.H{
"pay_link": payLink,
@@ -129,7 +129,7 @@ func RequestStripeAmount(c *gin.Context) {
var req StripePayRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
stripeAdaptor.RequestAmount(c, &req)
@@ -139,89 +139,93 @@ func RequestStripePay(c *gin.Context) {
var req StripePayRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
stripeAdaptor.RequestPay(c, &req)
}
func StripeWebhook(c *gin.Context) {
if setting.StripeWebhookSecret == "" {
log.Println("Stripe Webhook Secret 未配置,拒绝处理")
ctx := c.Request.Context()
if !isStripeWebhookEnabled() {
logger.LogWarn(ctx, fmt.Sprintf("Stripe webhook 被拒绝 reason=webhook_disabled path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
c.AbortWithStatus(http.StatusForbidden)
return
}
payload, err := io.ReadAll(c.Request.Body)
if err != nil {
log.Printf("解析Stripe Webhook参数失败: %v\n", err)
logger.LogError(ctx, fmt.Sprintf("Stripe webhook 读取请求体失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
c.AbortWithStatus(http.StatusServiceUnavailable)
return
}
signature := c.GetHeader("Stripe-Signature")
logger.LogInfo(ctx, fmt.Sprintf("Stripe webhook 收到请求 path=%q client_ip=%s signature=%q body=%q", c.Request.RequestURI, c.ClientIP(), signature, string(payload)))
event, err := webhook.ConstructEventWithOptions(payload, signature, setting.StripeWebhookSecret, webhook.ConstructEventOptions{
IgnoreAPIVersionMismatch: true,
})
if err != nil {
log.Printf("Stripe Webhook验签失败: %v\n", err)
logger.LogWarn(ctx, fmt.Sprintf("Stripe webhook 验签失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
c.AbortWithStatus(http.StatusBadRequest)
return
}
callerIp := c.ClientIP()
logger.LogInfo(ctx, fmt.Sprintf("Stripe webhook 验签成功 event_type=%s client_ip=%s path=%q", string(event.Type), callerIp, c.Request.RequestURI))
switch event.Type {
case stripe.EventTypeCheckoutSessionCompleted:
sessionCompleted(event)
sessionCompleted(ctx, event, callerIp)
case stripe.EventTypeCheckoutSessionExpired:
sessionExpired(event)
sessionExpired(ctx, event)
case stripe.EventTypeCheckoutSessionAsyncPaymentSucceeded:
sessionAsyncPaymentSucceeded(event)
sessionAsyncPaymentSucceeded(ctx, event, callerIp)
case stripe.EventTypeCheckoutSessionAsyncPaymentFailed:
sessionAsyncPaymentFailed(event)
sessionAsyncPaymentFailed(ctx, event, callerIp)
default:
log.Printf("不支持的Stripe Webhook事件类型: %s\n", event.Type)
logger.LogInfo(ctx, fmt.Sprintf("Stripe webhook 忽略事件 event_type=%s client_ip=%s", string(event.Type), callerIp))
}
c.Status(http.StatusOK)
}
func sessionCompleted(event stripe.Event) {
func sessionCompleted(ctx context.Context, event stripe.Event, callerIp string) {
customerId := event.GetObjectValue("customer")
referenceId := event.GetObjectValue("client_reference_id")
status := event.GetObjectValue("status")
if "complete" != status {
log.Println("错误的Stripe Checkout完成状态:", status, ",", referenceId)
logger.LogWarn(ctx, fmt.Sprintf("Stripe checkout.completed 状态异常,忽略处理 trade_no=%s status=%s client_ip=%s", referenceId, status, callerIp))
return
}
paymentStatus := event.GetObjectValue("payment_status")
if paymentStatus != "paid" {
log.Printf("Stripe Checkout 支付未完成,payment_status: %s, ref: %s(等待异步支付结果)", paymentStatus, referenceId)
logger.LogInfo(ctx, fmt.Sprintf("Stripe Checkout 支付未完成,等待异步结果 trade_no=%s payment_status=%s client_ip=%s", referenceId, paymentStatus, callerIp))
return
}
fulfillOrder(event, referenceId, customerId)
fulfillOrder(ctx, event, referenceId, customerId, callerIp)
}
// sessionAsyncPaymentSucceeded handles delayed payment methods (bank transfer, SEPA, etc.)
// that confirm payment after the checkout session completes.
func sessionAsyncPaymentSucceeded(event stripe.Event) {
func sessionAsyncPaymentSucceeded(ctx context.Context, event stripe.Event, callerIp string) {
customerId := event.GetObjectValue("customer")
referenceId := event.GetObjectValue("client_reference_id")
log.Printf("Stripe 异步支付成功: %s", referenceId)
logger.LogInfo(ctx, fmt.Sprintf("Stripe 异步支付成功 trade_no=%s client_ip=%s", referenceId, callerIp))
fulfillOrder(event, referenceId, customerId)
fulfillOrder(ctx, event, referenceId, customerId, callerIp)
}
// sessionAsyncPaymentFailed marks orders as failed when delayed payment methods
// ultimately fail (e.g. bank transfer not received, SEPA rejected).
func sessionAsyncPaymentFailed(event stripe.Event) {
func sessionAsyncPaymentFailed(ctx context.Context, event stripe.Event, callerIp string) {
referenceId := event.GetObjectValue("client_reference_id")
log.Printf("Stripe 异步支付失败: %s", referenceId)
logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败 trade_no=%s client_ip=%s", referenceId, callerIp))
if len(referenceId) == 0 {
log.Println("异步支付失败事件未提供支付单号")
logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败事件缺少订单号 client_ip=%s", callerIp))
return
}
@@ -230,32 +234,32 @@ func sessionAsyncPaymentFailed(event stripe.Event) {
topUp := model.GetTopUpByTradeNo(referenceId)
if topUp == nil {
log.Println("异步支付失败,充值订单不存在:", referenceId)
logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败但本地订单不存在 trade_no=%s client_ip=%s", referenceId, callerIp))
return
}
if topUp.PaymentMethod != PaymentMethodStripe {
log.Printf("异步支付失败订单支付方式不匹配: %s, ref: %s", topUp.PaymentMethod, referenceId)
if topUp.PaymentProvider != model.PaymentProviderStripe {
logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败订单支付网关不匹配 trade_no=%s payment_provider=%s client_ip=%s", referenceId, topUp.PaymentProvider, callerIp))
return
}
if topUp.Status != common.TopUpStatusPending {
log.Printf("异步支付失败订单状态非pending: %s, ref: %s", topUp.Status, referenceId)
logger.LogInfo(ctx, fmt.Sprintf("Stripe 异步支付失败订单状态非 pending,忽略处理 trade_no=%s status=%s client_ip=%s", referenceId, topUp.Status, callerIp))
return
}
topUp.Status = common.TopUpStatusFailed
if err := topUp.Update(); err != nil {
log.Printf("标记充值订单失败出错: %v, ref: %s", err, referenceId)
logger.LogError(ctx, fmt.Sprintf("Stripe 标记充值订单失败状态失败 trade_no=%s client_ip=%s error=%q", referenceId, callerIp, err.Error()))
return
}
log.Printf("充值订单已标记为失败: %s", referenceId)
logger.LogInfo(ctx, fmt.Sprintf("Stripe 充值订单已标记为失败 trade_no=%s client_ip=%s", referenceId, callerIp))
}
// fulfillOrder is the shared logic for crediting quota after payment is confirmed.
func fulfillOrder(event stripe.Event, referenceId string, customerId string) {
func fulfillOrder(ctx context.Context, event stripe.Event, referenceId string, customerId string, callerIp string) {
if len(referenceId) == 0 {
log.Println("未提供支付单号")
logger.LogWarn(ctx, fmt.Sprintf("Stripe 完成订单时缺少订单号 client_ip=%s", callerIp))
return
}
@@ -267,65 +271,60 @@ func fulfillOrder(event stripe.Event, referenceId string, customerId string) {
"currency": strings.ToUpper(event.GetObjectValue("currency")),
"event_type": string(event.Type),
}
if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(payload)); err == nil {
if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(payload), model.PaymentProviderStripe, ""); err == nil {
logger.LogInfo(ctx, fmt.Sprintf("Stripe 订阅订单处理成功 trade_no=%s event_type=%s client_ip=%s", referenceId, string(event.Type), callerIp))
return
} else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) {
log.Println("complete subscription order failed:", err.Error(), referenceId)
logger.LogError(ctx, fmt.Sprintf("Stripe 订阅订单处理失败 trade_no=%s event_type=%s client_ip=%s error=%q", referenceId, string(event.Type), callerIp, err.Error()))
return
}
err := model.Recharge(referenceId, customerId)
err := model.Recharge(referenceId, customerId, callerIp)
if err != nil {
log.Println(err.Error(), referenceId)
logger.LogError(ctx, fmt.Sprintf("Stripe 充值处理失败 trade_no=%s event_type=%s client_ip=%s error=%q", referenceId, string(event.Type), callerIp, err.Error()))
return
}
total, _ := strconv.ParseFloat(event.GetObjectValue("amount_total"), 64)
currency := strings.ToUpper(event.GetObjectValue("currency"))
log.Printf("收到款项:%s, %.2f(%s)", referenceId, total/100, currency)
logger.LogInfo(ctx, fmt.Sprintf("Stripe 充值成功 trade_no=%s amount_total=%.2f currency=%s event_type=%s client_ip=%s", referenceId, total/100, currency, string(event.Type), callerIp))
}
func sessionExpired(event stripe.Event) {
func sessionExpired(ctx context.Context, event stripe.Event) {
referenceId := event.GetObjectValue("client_reference_id")
status := event.GetObjectValue("status")
if "expired" != status {
log.Println("错误的Stripe Checkout过期状态:", status, ",", referenceId)
logger.LogWarn(ctx, fmt.Sprintf("Stripe checkout.expired 状态异常,忽略处理 trade_no=%s status=%s", referenceId, status))
return
}
if len(referenceId) == 0 {
log.Println("未提供支付单号")
logger.LogWarn(ctx, "Stripe checkout.expired 缺少订单号")
return
}
// Subscription order expiration
LockOrder(referenceId)
defer UnlockOrder(referenceId)
if err := model.ExpireSubscriptionOrder(referenceId); err == nil {
if err := model.ExpireSubscriptionOrder(referenceId, model.PaymentProviderStripe); err == nil {
logger.LogInfo(ctx, fmt.Sprintf("Stripe 订阅订单已过期 trade_no=%s", referenceId))
return
} else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) {
log.Println("过期订阅订单失败", referenceId, ", err:", err.Error())
logger.LogError(ctx, fmt.Sprintf("Stripe 订阅订单过期处理失败 trade_no=%s error=%q", referenceId, err.Error()))
return
}
topUp := model.GetTopUpByTradeNo(referenceId)
if topUp == nil {
log.Println("充值订单不存在", referenceId)
err := model.UpdatePendingTopUpStatus(referenceId, model.PaymentProviderStripe, common.TopUpStatusExpired)
if errors.Is(err, model.ErrTopUpNotFound) {
logger.LogWarn(ctx, fmt.Sprintf("Stripe 充值订单不存在,无法标记过期 trade_no=%s", referenceId))
return
}
if topUp.Status != common.TopUpStatusPending {
log.Println("充值订单状态错误", referenceId)
}
topUp.Status = common.TopUpStatusExpired
err := topUp.Update()
if err != nil {
log.Println("过期充值订单失败", referenceId, ", err:", err.Error())
logger.LogError(ctx, fmt.Sprintf("Stripe 充值订单过期处理失败 trade_no=%s error=%q", referenceId, err.Error()))
return
}
log.Println("充值订单已过期", referenceId)
logger.LogInfo(ctx, fmt.Sprintf("Stripe 充值订单已过期 trade_no=%s", referenceId))
}
// genStripeLink generates a Stripe Checkout session URL for payment.
+80 -42
View File
@@ -1,14 +1,15 @@
package controller
import (
"errors"
"fmt"
"io"
"log"
"net/http"
"strconv"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting"
@@ -99,28 +100,57 @@ type WaffoPayRequest struct {
PayMethodName string `json:"pay_method_name"` // Deprecated: 兼容旧前端,优先使用 pay_method_index
}
func RequestWaffoAmount(c *gin.Context) {
var req WaffoPayRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
waffoMinTopup := int64(setting.WaffoMinTopUp)
if req.Amount < waffoMinTopup {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", waffoMinTopup)})
return
}
id := c.GetInt("id")
group, err := model.GetUserGroup(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "获取用户分组失败"})
return
}
payMoney := getWaffoPayMoney(float64(req.Amount), group)
if payMoney <= 0.01 {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "充值金额过低"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
}
// RequestWaffoPay 创建 Waffo 支付订单
func RequestWaffoPay(c *gin.Context) {
if !setting.WaffoEnabled {
c.JSON(200, gin.H{"message": "error", "data": "Waffo 支付未启用"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo 支付未启用"})
return
}
var req WaffoPayRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
waffoMinTopup := int64(setting.WaffoMinTopUp)
if req.Amount < waffoMinTopup {
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", waffoMinTopup)})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", waffoMinTopup)})
return
}
id := c.GetInt("id")
user, err := model.GetUserById(id, false)
if err != nil || user == nil {
c.JSON(200, gin.H{"message": "error", "data": "用户不存在"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "用户不存在"})
return
}
@@ -131,8 +161,8 @@ func RequestWaffoPay(c *gin.Context) {
// 新协议:按索引查找
idx := *req.PayMethodIndex
if idx < 0 || idx >= len(methods) {
log.Printf("Waffo 无效的支付方式索引: %d, UserId=%d, 可用范围: [0, %d)", idx, id, len(methods))
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付方式"})
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo 支付方式索引无效 user_id=%d pay_method_index=%d method_count=%d", id, idx, len(methods)))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "不支持的支付方式"})
return
}
resolvedPayMethodType = methods[idx].PayMethodType
@@ -149,8 +179,8 @@ func RequestWaffoPay(c *gin.Context) {
}
}
if !valid {
log.Printf("Waffo 无效的支付方式: PayMethodType=%s, PayMethodName=%s, UserId=%d", req.PayMethodType, req.PayMethodName, id)
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付方式"})
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo 支付方式无效 user_id=%d pay_method_type=%s pay_method_name=%q", id, req.PayMethodType, req.PayMethodName))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "不支持的支付方式"})
return
}
}
@@ -159,7 +189,7 @@ func RequestWaffoPay(c *gin.Context) {
group, _ := model.GetUserGroup(id, true)
payMoney := getWaffoPayMoney(float64(req.Amount), group)
if payMoney < 0.01 {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "充值金额过低"})
return
}
@@ -178,26 +208,27 @@ func RequestWaffoPay(c *gin.Context) {
// 创建本地订单
topUp := &model.TopUp{
UserId: id,
Amount: amount,
Money: payMoney,
TradeNo: merchantOrderId,
PaymentMethod: "waffo",
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
UserId: id,
Amount: amount,
Money: payMoney,
TradeNo: merchantOrderId,
PaymentMethod: model.PaymentMethodWaffo,
PaymentProvider: model.PaymentProviderWaffo,
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
}
if err := topUp.Insert(); err != nil {
log.Printf("Waffo 创建本地订单失败: %v", err)
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, merchantOrderId, req.Amount, err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
return
}
sdk, err := getWaffoSDK()
if err != nil {
log.Printf("Waffo SDK 初始化失败: %v", err)
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo SDK 初始化失败 user_id=%d trade_no=%s error=%q", id, merchantOrderId, err.Error()))
topUp.Status = common.TopUpStatusFailed
_ = topUp.Update()
c.JSON(200, gin.H{"message": "error", "data": "支付配置错误"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "支付配置错误"})
return
}
@@ -238,29 +269,29 @@ func RequestWaffoPay(c *gin.Context) {
}
resp, err := sdk.Order().Create(c.Request.Context(), createParams, nil)
if err != nil {
log.Printf("Waffo 创建订单失败: %v", err)
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 创建订单失败 user_id=%d trade_no=%s error=%q", id, merchantOrderId, err.Error()))
topUp.Status = common.TopUpStatusFailed
_ = topUp.Update()
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
return
}
if !resp.IsSuccess() {
log.Printf("Waffo 创建订单业务失败: [%s] %s, 完整响应: %+v", resp.Code, resp.Message, resp)
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo 创建订单业务失败 user_id=%d trade_no=%s code=%s message=%q response=%q", id, merchantOrderId, resp.Code, resp.Message, common.GetJsonString(resp)))
topUp.Status = common.TopUpStatusFailed
_ = topUp.Update()
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
return
}
orderData := resp.GetData()
log.Printf("Waffo 订单创建成功 - 用户: %d, 订单: %s, 金额: %.2f", id, merchantOrderId, payMoney)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo 充值订单创建成功 user_id=%d trade_no=%s amount=%d money=%.2f pay_method_type=%s pay_method_name=%q", id, merchantOrderId, req.Amount, payMoney, resolvedPayMethodType, resolvedPayMethodName))
paymentUrl := orderData.FetchRedirectURL()
if paymentUrl == "" {
paymentUrl = orderData.OrderAction
}
c.JSON(200, gin.H{
c.JSON(http.StatusOK, gin.H{
"message": "success",
"data": gin.H{
"payment_url": paymentUrl,
@@ -287,16 +318,22 @@ type webhookSubscriptionInfo struct {
// WaffoWebhook 处理 Waffo 回调通知(支付/退款/订阅)
func WaffoWebhook(c *gin.Context) {
if !isWaffoWebhookEnabled() {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo webhook 被拒绝 reason=webhook_disabled path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
c.AbortWithStatus(http.StatusForbidden)
return
}
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
log.Printf("Waffo Webhook 读取 body 失败: %v", err)
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo webhook 读取请求体失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
c.AbortWithStatus(http.StatusBadRequest)
return
}
sdk, err := getWaffoSDK()
if err != nil {
log.Printf("Waffo Webhook SDK 初始化失败: %v", err)
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo webhook SDK 初始化失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
c.AbortWithStatus(http.StatusInternalServerError)
return
}
@@ -304,17 +341,18 @@ func WaffoWebhook(c *gin.Context) {
wh := sdk.Webhook()
bodyStr := string(bodyBytes)
signature := c.GetHeader("X-SIGNATURE")
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo webhook 收到请求 path=%q client_ip=%s signature=%q body=%q", c.Request.RequestURI, c.ClientIP(), signature, bodyStr))
// 验证请求签名
if !wh.VerifySignature(bodyStr, signature) {
log.Printf("Waffo webhook 签名验证失败")
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo webhook 验签失败 path=%q client_ip=%s signature=%q body=%q", c.Request.RequestURI, c.ClientIP(), signature, bodyStr))
c.AbortWithStatus(http.StatusBadRequest)
return
}
var event core.WebhookEvent
if err := common.Unmarshal(bodyBytes, &event); err != nil {
log.Printf("Waffo Webhook 解析失败: %v", err)
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo webhook 解析失败 path=%q client_ip=%s error=%q body=%q", c.Request.RequestURI, c.ClientIP(), err.Error(), bodyStr))
sendWaffoWebhookResponse(c, wh, false, "invalid payload")
return
}
@@ -324,14 +362,14 @@ func WaffoWebhook(c *gin.Context) {
// 解析为扩展类型,区分普通支付和订阅支付
var payload webhookPayloadWithSubInfo
if err := common.Unmarshal(bodyBytes, &payload); err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 支付回调载荷解析失败 event_type=%s client_ip=%s error=%q body=%q", event.EventType, c.ClientIP(), err.Error(), bodyStr))
sendWaffoWebhookResponse(c, wh, false, "invalid payment payload")
return
}
log.Printf("Waffo Webhook - EventType: %s, MerchantOrderId: %s, OrderStatus: %s",
event.EventType, payload.Result.MerchantOrderID, payload.Result.OrderStatus)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo webhook 验签并解析成功 event_type=%s merchant_order_id=%s order_status=%s client_ip=%s", event.EventType, payload.Result.MerchantOrderID, payload.Result.OrderStatus, c.ClientIP()))
handleWaffoPayment(c, wh, &payload.Result.PaymentNotificationResult)
default:
log.Printf("Waffo Webhook 未知事件: %s", event.EventType)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo webhook 忽略事件 event_type=%s client_ip=%s", event.EventType, c.ClientIP()))
sendWaffoWebhookResponse(c, wh, true, "")
}
}
@@ -339,13 +377,13 @@ func WaffoWebhook(c *gin.Context) {
// handleWaffoPayment 处理支付完成通知
func handleWaffoPayment(c *gin.Context, wh *core.WebhookHandler, result *core.PaymentNotificationResult) {
if result.OrderStatus != "PAY_SUCCESS" {
log.Printf("Waffo 订单状态非成功: %s, 订单: %s", result.OrderStatus, result.MerchantOrderID)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo 订单状态非成功,忽略充值 trade_no=%s order_status=%s client_ip=%s", result.MerchantOrderID, result.OrderStatus, c.ClientIP()))
// 终态失败订单标记为 failed,避免永远停在 pending
if result.MerchantOrderID != "" {
if topUp := model.GetTopUpByTradeNo(result.MerchantOrderID); topUp != nil &&
topUp.Status == common.TopUpStatusPending {
topUp.Status = common.TopUpStatusFailed
_ = topUp.Update()
if err := model.UpdatePendingTopUpStatus(result.MerchantOrderID, model.PaymentProviderWaffo, common.TopUpStatusFailed); err != nil &&
!errors.Is(err, model.ErrTopUpNotFound) &&
!errors.Is(err, model.ErrTopUpStatusInvalid) {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 标记失败订单状态失败 trade_no=%s error=%q", result.MerchantOrderID, err.Error()))
}
}
sendWaffoWebhookResponse(c, wh, true, "")
@@ -357,13 +395,13 @@ func handleWaffoPayment(c *gin.Context, wh *core.WebhookHandler, result *core.Pa
LockOrder(merchantOrderId)
defer UnlockOrder(merchantOrderId)
if err := model.RechargeWaffo(merchantOrderId); err != nil {
log.Printf("Waffo 充值处理失败: %v, 订单: %s", err, merchantOrderId)
if err := model.RechargeWaffo(merchantOrderId, c.ClientIP()); err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 充值处理失败 trade_no=%s client_ip=%s error=%q", merchantOrderId, c.ClientIP(), err.Error()))
sendWaffoWebhookResponse(c, wh, false, err.Error())
return
}
log.Printf("Waffo 充值成功 - 订单: %s", merchantOrderId)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo 充值成功 trade_no=%s client_ip=%s", merchantOrderId, c.ClientIP()))
sendWaffoWebhookResponse(c, wh, true, "")
}
+260
View File
@@ -0,0 +1,260 @@
package controller
import (
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-gonic/gin"
"github.com/shopspring/decimal"
"github.com/thanhpk/randstr"
)
type WaffoPancakePayRequest struct {
Amount int64 `json:"amount"`
}
func RequestWaffoPancakeAmount(c *gin.Context) {
var req WaffoPancakePayRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
if req.Amount < int64(setting.WaffoPancakeMinTopUp) {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", setting.WaffoPancakeMinTopUp)})
return
}
id := c.GetInt("id")
group, err := model.GetUserGroup(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "获取用户分组失败"})
return
}
payMoney := getWaffoPancakePayMoney(req.Amount, group)
if payMoney <= 0.01 {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "充值金额过低"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "success", "data": fmt.Sprintf("%.2f", payMoney)})
}
func getWaffoPancakePayMoney(amount int64, group string) float64 {
dAmount := decimal.NewFromInt(amount)
if operation_setting.GetQuotaDisplayType() == operation_setting.QuotaDisplayTypeTokens {
dAmount = dAmount.Div(decimal.NewFromFloat(common.QuotaPerUnit))
}
topupGroupRatio := common.GetTopupGroupRatio(group)
if topupGroupRatio == 0 {
topupGroupRatio = 1
}
discount := 1.0
if ds, ok := operation_setting.GetPaymentSetting().AmountDiscount[int(amount)]; ok && ds > 0 {
discount = ds
}
payMoney := dAmount.
Mul(decimal.NewFromFloat(setting.WaffoPancakeUnitPrice)).
Mul(decimal.NewFromFloat(topupGroupRatio)).
Mul(decimal.NewFromFloat(discount))
return payMoney.InexactFloat64()
}
func normalizeWaffoPancakeTopUpAmount(amount int64) int64 {
if operation_setting.GetQuotaDisplayType() != operation_setting.QuotaDisplayTypeTokens {
return amount
}
normalized := decimal.NewFromInt(amount).
Div(decimal.NewFromFloat(common.QuotaPerUnit)).
IntPart()
if normalized < 1 {
return 1
}
return normalized
}
func formatWaffoPancakeAmount(payMoney float64) string {
return decimal.NewFromFloat(payMoney).StringFixed(2)
}
func getWaffoPancakeBuyerEmail(user *model.User) string {
if user != nil && strings.TrimSpace(user.Email) != "" {
return user.Email
}
if user != nil {
return fmt.Sprintf("%d@new-api.local", user.Id)
}
return ""
}
func getWaffoPancakeReturnURL() string {
if strings.TrimSpace(setting.WaffoPancakeReturnURL) != "" {
return setting.WaffoPancakeReturnURL
}
return strings.TrimRight(system_setting.ServerAddress, "/") + "/console/topup?show_history=true"
}
func RequestWaffoPancakePay(c *gin.Context) {
if !setting.WaffoPancakeEnabled {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 支付未启用"})
return
}
currentWebhookKey := setting.WaffoPancakeWebhookPublicKey
if setting.WaffoPancakeSandbox {
currentWebhookKey = setting.WaffoPancakeWebhookTestKey
}
if strings.TrimSpace(setting.WaffoPancakeMerchantID) == "" ||
strings.TrimSpace(setting.WaffoPancakePrivateKey) == "" ||
strings.TrimSpace(currentWebhookKey) == "" ||
strings.TrimSpace(setting.WaffoPancakeStoreID) == "" ||
strings.TrimSpace(setting.WaffoPancakeProductID) == "" {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 配置不完整"})
return
}
var req WaffoPancakePayRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
if req.Amount < int64(setting.WaffoPancakeMinTopUp) {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", setting.WaffoPancakeMinTopUp)})
return
}
id := c.GetInt("id")
user, err := model.GetUserById(id, false)
if err != nil || user == nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "用户不存在"})
return
}
group, err := model.GetUserGroup(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "获取用户分组失败"})
return
}
payMoney := getWaffoPancakePayMoney(req.Amount, group)
if payMoney < 0.01 {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "充值金额过低"})
return
}
tradeNo := fmt.Sprintf("WAFFO_PANCAKE-%d-%d-%s", id, time.Now().UnixMilli(), randstr.String(6))
topUp := &model.TopUp{
UserId: id,
Amount: normalizeWaffoPancakeTopUpAmount(req.Amount),
Money: payMoney,
TradeNo: tradeNo,
PaymentMethod: model.PaymentMethodWaffoPancake,
PaymentProvider: model.PaymentProviderWaffoPancake,
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
}
if err := topUp.Insert(); err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, tradeNo, req.Amount, err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
return
}
expiresInSeconds := 45 * 60
session, err := service.CreateWaffoPancakeCheckoutSession(c.Request.Context(), &service.WaffoPancakeCreateSessionParams{
StoreID: setting.WaffoPancakeStoreID,
ProductID: setting.WaffoPancakeProductID,
ProductType: "onetime",
Currency: strings.ToUpper(strings.TrimSpace(setting.WaffoPancakeCurrency)),
PriceSnapshot: &service.WaffoPancakePriceSnapshot{
Amount: formatWaffoPancakeAmount(payMoney),
TaxIncluded: false,
TaxCategory: "saas",
},
BuyerEmail: getWaffoPancakeBuyerEmail(user),
SuccessURL: getWaffoPancakeReturnURL(),
ExpiresInSeconds: &expiresInSeconds,
})
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 创建结账会话失败 user_id=%d trade_no=%s error=%q", id, tradeNo, err.Error()))
topUp.Status = common.TopUpStatusFailed
_ = topUp.Update()
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
return
}
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo Pancake 充值订单创建成功 user_id=%d trade_no=%s session_id=%s amount=%d money=%.2f", id, tradeNo, session.SessionID, req.Amount, payMoney))
c.JSON(http.StatusOK, gin.H{
"message": "success",
"data": gin.H{
"checkout_url": session.CheckoutURL,
"session_id": session.SessionID,
"expires_at": session.ExpiresAt,
"order_id": tradeNo,
},
})
}
func WaffoPancakeWebhook(c *gin.Context) {
if !isWaffoPancakeWebhookEnabled() {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 被拒绝 reason=webhook_disabled path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
c.String(http.StatusForbidden, "webhook disabled")
return
}
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 读取请求体失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
c.String(http.StatusBadRequest, "bad request")
return
}
signature := c.GetHeader("X-Waffo-Signature")
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 收到请求 path=%q client_ip=%s signature=%q body=%q", c.Request.RequestURI, c.ClientIP(), signature, string(bodyBytes)))
event, err := service.VerifyConfiguredWaffoPancakeWebhook(string(bodyBytes), signature)
if err != nil {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 验签失败 path=%q client_ip=%s signature=%q body=%q error=%q", c.Request.RequestURI, c.ClientIP(), signature, string(bodyBytes), err.Error()))
c.String(http.StatusUnauthorized, "invalid signature")
return
}
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 验签成功 event_type=%s event_id=%s order_id=%s client_ip=%s", event.NormalizedEventType(), event.ID, event.Data.OrderID, c.ClientIP()))
if event.NormalizedEventType() != "order.completed" {
c.String(http.StatusOK, "OK")
return
}
tradeNo, err := service.ResolveWaffoPancakeTradeNo(event)
if err != nil {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 订单号映射失败 event_id=%s order_id=%s error=%q", event.ID, event.Data.OrderID, err.Error()))
c.String(http.StatusOK, "OK")
return
}
LockOrder(tradeNo)
defer UnlockOrder(tradeNo)
if err := model.RechargeWaffoPancake(tradeNo); err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 充值处理失败 trade_no=%s event_id=%s order_id=%s client_ip=%s error=%q", tradeNo, event.ID, event.Data.OrderID, c.ClientIP(), err.Error()))
c.String(http.StatusInternalServerError, "retry")
return
}
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo Pancake 充值成功 trade_no=%s event_id=%s order_id=%s client_ip=%s", tradeNo, event.ID, event.Data.OrderID, c.ClientIP()))
c.String(http.StatusOK, "OK")
}
+91
View File
@@ -0,0 +1,91 @@
package controller
import (
"testing"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/stretchr/testify/require"
)
func TestFormatWaffoPancakeAmount_UsesDisplayPriceString(t *testing.T) {
testCases := []struct {
name string
amount float64
expected string
}{
{name: "whole amount", amount: 29, expected: "29.00"},
{name: "decimal amount", amount: 29.9, expected: "29.90"},
{name: "round half up to cents", amount: 29.999, expected: "30.00"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
require.Equal(t, tc.expected, formatWaffoPancakeAmount(tc.amount))
})
}
}
func TestGetWaffoPancakePayMoney(t *testing.T) {
originalUnitPrice := setting.WaffoPancakeUnitPrice
originalQuotaDisplayType := operation_setting.GetGeneralSetting().QuotaDisplayType
originalDiscounts := make(map[int]float64, len(operation_setting.GetPaymentSetting().AmountDiscount))
for k, v := range operation_setting.GetPaymentSetting().AmountDiscount {
originalDiscounts[k] = v
}
originalTopupGroupRatio := common.TopupGroupRatio2JSONString()
t.Cleanup(func() {
setting.WaffoPancakeUnitPrice = originalUnitPrice
operation_setting.GetGeneralSetting().QuotaDisplayType = originalQuotaDisplayType
operation_setting.GetPaymentSetting().AmountDiscount = originalDiscounts
require.NoError(t, common.UpdateTopupGroupRatioByJSONString(originalTopupGroupRatio))
})
setting.WaffoPancakeUnitPrice = 2.5
operation_setting.GetPaymentSetting().AmountDiscount = map[int]float64{
10: 0.8,
int(common.QuotaPerUnit * 3): 0.5,
20: 0,
}
require.NoError(t, common.UpdateTopupGroupRatioByJSONString(`{"default":1,"vip":1.2}`))
testCases := []struct {
name string
amount int64
group string
quotaDisplayType string
expected float64
}{
{
name: "currency display applies unit price group ratio and discount",
amount: 10,
group: "vip",
quotaDisplayType: operation_setting.QuotaDisplayTypeUSD,
expected: 24,
},
{
name: "tokens display converts quota to display units before pricing",
amount: int64(common.QuotaPerUnit * 3),
group: "vip",
quotaDisplayType: operation_setting.QuotaDisplayTypeTokens,
expected: 4.5,
},
{
name: "non-positive discount falls back to no discount",
amount: 20,
group: "default",
quotaDisplayType: operation_setting.QuotaDisplayTypeUSD,
expected: 50,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
operation_setting.GetGeneralSetting().QuotaDisplayType = tc.quotaDisplayType
actual := getWaffoPancakePayMoney(tc.amount, tc.group)
require.InDelta(t, tc.expected, actual, 0.000001)
})
}
}
+8 -4
View File
@@ -2,7 +2,6 @@ package controller
import (
"errors"
"fmt"
"net/http"
"strconv"
@@ -542,10 +541,15 @@ func AdminDisable2FA(c *gin.Context) {
return
}
// 记录操作日志
// 记录操作日志:管理员身份通过 admin_info 传递,避免在非管理员可见的日志内容中泄露。
adminId := c.GetInt("id")
model.RecordLog(userId, model.LogTypeManage,
fmt.Sprintf("管理员(ID:%d)强制禁用了用户的两步验证", adminId))
adminName := c.GetString("username")
adminInfo := map[string]interface{}{
"admin_id": adminId,
"admin_username": adminName,
}
model.RecordLogWithAdminInfo(userId, model.LogTypeManage,
"管理员强制禁用了用户的两步验证", adminInfo)
c.JSON(http.StatusOK, gin.H{
"success": true,
+29 -6
View File
@@ -91,6 +91,7 @@ func Login(c *gin.Context) {
// setup session & cookies and then return user info
func setupLogin(user *model.User, c *gin.Context) {
model.UpdateUserLastLoginAt(user.Id)
session := sessions.Default(c)
session.Set("id", user.Id)
session.Set("username", user.Username)
@@ -891,6 +892,11 @@ func ManageUser(c *gin.Context) {
})
return
}
// 删除用户后,强制清理 Redis 中所有该用户令牌的缓存,
// 避免已缓存的令牌在 TTL 过期前仍能通过 TokenAuth 校验。
if err := model.InvalidateUserTokensCache(user.Id); err != nil {
common.SysLog(fmt.Sprintf("failed to invalidate tokens cache for user %d: %s", user.Id, err.Error()))
}
case "promote":
if myRole != common.RoleRootUser {
common.ApiErrorI18n(c, i18n.MsgUserAdminCannotPromote)
@@ -913,6 +919,11 @@ func ManageUser(c *gin.Context) {
user.Role = common.RoleCommonUser
case "add_quota":
adminName := c.GetString("username")
adminId := c.GetInt("id")
adminInfo := map[string]interface{}{
"admin_id": adminId,
"admin_username": adminName,
}
switch req.Mode {
case "add":
if req.Value <= 0 {
@@ -923,8 +934,8 @@ func ManageUser(c *gin.Context) {
common.ApiError(c, err)
return
}
model.RecordLog(user.Id, model.LogTypeManage,
fmt.Sprintf("管理员(%s)增加用户额度 %s", adminName, logger.LogQuota(req.Value)))
model.RecordLogWithAdminInfo(user.Id, model.LogTypeManage,
fmt.Sprintf("管理员增加用户额度 %s", logger.LogQuota(req.Value)), adminInfo)
case "subtract":
if req.Value <= 0 {
common.ApiErrorI18n(c, i18n.MsgUserQuotaChangeZero)
@@ -934,16 +945,16 @@ func ManageUser(c *gin.Context) {
common.ApiError(c, err)
return
}
model.RecordLog(user.Id, model.LogTypeManage,
fmt.Sprintf("管理员(%s)减少用户额度 %s", adminName, logger.LogQuota(req.Value)))
model.RecordLogWithAdminInfo(user.Id, model.LogTypeManage,
fmt.Sprintf("管理员减少用户额度 %s", logger.LogQuota(req.Value)), adminInfo)
case "override":
oldQuota := user.Quota
if err := model.DB.Model(&model.User{}).Where("id = ?", user.Id).Update("quota", req.Value).Error; err != nil {
common.ApiError(c, err)
return
}
model.RecordLog(user.Id, model.LogTypeManage,
fmt.Sprintf("管理员(%s)覆盖用户额度从 %s 为 %s", adminName, logger.LogQuota(oldQuota), logger.LogQuota(req.Value)))
model.RecordLogWithAdminInfo(user.Id, model.LogTypeManage,
fmt.Sprintf("管理员覆盖用户额度从 %s 为 %s", logger.LogQuota(oldQuota), logger.LogQuota(req.Value)), adminInfo)
default:
common.ApiErrorI18n(c, i18n.MsgInvalidParams)
return
@@ -959,6 +970,18 @@ func ManageUser(c *gin.Context) {
common.ApiError(c, err)
return
}
// 禁用 / 角色调整后,强制失效用户缓存与其全部令牌缓存,
// 避免在 Redis TTL 过期前仍使用旧状态(尤其是禁用后仍可发起请求的问题)。
// InvalidateUserCache 会让下一次 GetUserCache 从数据库重新加载,
// InvalidateUserTokensCache 则确保令牌侧的缓存也同步刷新。
if req.Action == "disable" || req.Action == "promote" || req.Action == "demote" {
if err := model.InvalidateUserCache(user.Id); err != nil {
common.SysLog(fmt.Sprintf("failed to invalidate user cache for user %d: %s", user.Id, err.Error()))
}
if err := model.InvalidateUserTokensCache(user.Id); err != nil {
common.SysLog(fmt.Sprintf("failed to invalidate tokens cache for user %d: %s", user.Id, err.Error()))
}
}
clearUser := model.User{
Role: user.Role,
Status: user.Status,