Merge remote-tracking branch 'origin/main' into codex/redeem-subscription
# Conflicts: # web/classic/src/components/topup/index.jsx
This commit is contained in:
+94
-14
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/middleware"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/pkg/billingexpr"
|
||||
"github.com/QuantumNous/new-api/relay"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
||||
@@ -233,6 +234,15 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
|
||||
info.IsChannelTest = true
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
err = attachTestBillingRequestInput(info, request)
|
||||
if err != nil {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
|
||||
}
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
return testResult{
|
||||
@@ -460,7 +470,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
|
||||
newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
|
||||
}
|
||||
}
|
||||
if bodyErr := detectErrorFromTestResponseBody(respBody); bodyErr != nil {
|
||||
if bodyErr := validateTestResponseBody(respBody, isStream); bodyErr != nil {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: bodyErr,
|
||||
@@ -469,21 +479,11 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
|
||||
}
|
||||
info.SetEstimatePromptTokens(usage.PromptTokens)
|
||||
|
||||
quota := 0
|
||||
if !priceData.UsePrice {
|
||||
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
|
||||
quota = int(math.Round(float64(quota) * priceData.ModelRatio))
|
||||
if priceData.ModelRatio != 0 && quota <= 0 {
|
||||
quota = 1
|
||||
}
|
||||
} else {
|
||||
quota = int(priceData.ModelPrice * common.QuotaPerUnit)
|
||||
}
|
||||
quota, tieredResult := settleTestQuota(info, priceData, usage)
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
consumedTime := float64(milliseconds) / 1000.0
|
||||
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
|
||||
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
other := buildTestLogOther(c, info, priceData, usage, tieredResult)
|
||||
model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
|
||||
ChannelId: channel.Id,
|
||||
PromptTokens: usage.PromptTokens,
|
||||
@@ -505,6 +505,50 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
|
||||
}
|
||||
}
|
||||
|
||||
func attachTestBillingRequestInput(info *relaycommon.RelayInfo, request dto.Request) error {
|
||||
if info == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
input, err := helper.BuildBillingExprRequestInputFromRequest(request, info.RequestHeaders)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
info.BillingRequestInput = &input
|
||||
return nil
|
||||
}
|
||||
|
||||
func settleTestQuota(info *relaycommon.RelayInfo, priceData types.PriceData, usage *dto.Usage) (int, *billingexpr.TieredResult) {
|
||||
if usage != nil && info != nil && info.TieredBillingSnapshot != nil {
|
||||
isClaudeUsageSemantic := usage.UsageSemantic == "anthropic" || info.GetFinalRequestRelayFormat() == types.RelayFormatClaude
|
||||
usedVars := billingexpr.UsedVars(info.TieredBillingSnapshot.ExprString)
|
||||
if ok, quota, result := service.TryTieredSettle(info, service.BuildTieredTokenParams(usage, isClaudeUsageSemantic, usedVars)); ok {
|
||||
return quota, result
|
||||
}
|
||||
}
|
||||
|
||||
quota := 0
|
||||
if !priceData.UsePrice {
|
||||
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
|
||||
quota = int(math.Round(float64(quota) * priceData.ModelRatio))
|
||||
if priceData.ModelRatio != 0 && quota <= 0 {
|
||||
quota = 1
|
||||
}
|
||||
return quota, nil
|
||||
}
|
||||
|
||||
return int(priceData.ModelPrice * common.QuotaPerUnit), nil
|
||||
}
|
||||
|
||||
func buildTestLogOther(c *gin.Context, info *relaycommon.RelayInfo, priceData types.PriceData, usage *dto.Usage, tieredResult *billingexpr.TieredResult) map[string]interface{} {
|
||||
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
|
||||
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
if tieredResult != nil {
|
||||
service.InjectTieredBillingInfo(other, info, tieredResult)
|
||||
}
|
||||
return other
|
||||
}
|
||||
|
||||
func coerceTestUsage(usageAny any, isStream bool, estimatePromptTokens int) (*dto.Usage, error) {
|
||||
switch u := usageAny.(type) {
|
||||
case *dto.Usage:
|
||||
@@ -570,6 +614,42 @@ func detectErrorFromTestResponseBody(respBody []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateStreamTestResponseBody(respBody []byte) error {
|
||||
b := bytes.TrimSpace(respBody)
|
||||
if len(b) == 0 {
|
||||
return errors.New("stream response body is empty")
|
||||
}
|
||||
|
||||
for _, line := range bytes.Split(b, []byte{'\n'}) {
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data:")) {
|
||||
continue
|
||||
}
|
||||
payload := bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data:")))
|
||||
if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) {
|
||||
continue
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return errors.New("stream response body does not contain a valid stream event")
|
||||
}
|
||||
|
||||
func validateTestResponseBody(respBody []byte, isStream bool) error {
|
||||
if bodyErr := detectErrorFromTestResponseBody(respBody); bodyErr != nil {
|
||||
return bodyErr
|
||||
}
|
||||
if isStream {
|
||||
return validateStreamTestResponseBody(respBody)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func shouldUseStreamForAutomaticChannelTest(channel *model.Channel) bool {
|
||||
return channel != nil && channel.Type == constant.ChannelTypeCodex
|
||||
}
|
||||
|
||||
func detectErrorMessageFromJSONBytes(jsonBytes []byte) string {
|
||||
if len(jsonBytes) == 0 {
|
||||
return ""
|
||||
@@ -822,7 +902,7 @@ func testAllChannels(notify bool) error {
|
||||
}
|
||||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
||||
tik := time.Now()
|
||||
result := testChannel(channel, "", "", false)
|
||||
result := testChannel(channel, "", "", shouldUseStreamForAutomaticChannelTest(channel))
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/pkg/billingexpr"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSettleTestQuotaUsesTieredBilling(t *testing.T) {
|
||||
info := &relaycommon.RelayInfo{
|
||||
TieredBillingSnapshot: &billingexpr.BillingSnapshot{
|
||||
BillingMode: "tiered_expr",
|
||||
ExprString: `param("stream") == true ? tier("stream", p * 3) : tier("base", p * 2)`,
|
||||
ExprHash: billingexpr.ExprHashString(`param("stream") == true ? tier("stream", p * 3) : tier("base", p * 2)`),
|
||||
GroupRatio: 1,
|
||||
EstimatedTier: "stream",
|
||||
QuotaPerUnit: common.QuotaPerUnit,
|
||||
ExprVersion: 1,
|
||||
},
|
||||
BillingRequestInput: &billingexpr.RequestInput{
|
||||
Body: []byte(`{"stream":true}`),
|
||||
},
|
||||
}
|
||||
|
||||
quota, result := settleTestQuota(info, types.PriceData{
|
||||
ModelRatio: 1,
|
||||
CompletionRatio: 2,
|
||||
}, &dto.Usage{
|
||||
PromptTokens: 1000,
|
||||
})
|
||||
|
||||
require.Equal(t, 1500, quota)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "stream", result.MatchedTier)
|
||||
}
|
||||
|
||||
func TestBuildTestLogOtherInjectsTieredInfo(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
|
||||
info := &relaycommon.RelayInfo{
|
||||
TieredBillingSnapshot: &billingexpr.BillingSnapshot{
|
||||
BillingMode: "tiered_expr",
|
||||
ExprString: `tier("base", p * 2)`,
|
||||
},
|
||||
ChannelMeta: &relaycommon.ChannelMeta{},
|
||||
}
|
||||
priceData := types.PriceData{
|
||||
GroupRatioInfo: types.GroupRatioInfo{GroupRatio: 1},
|
||||
}
|
||||
usage := &dto.Usage{
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 12,
|
||||
},
|
||||
}
|
||||
|
||||
other := buildTestLogOther(ctx, info, priceData, usage, &billingexpr.TieredResult{
|
||||
MatchedTier: "base",
|
||||
})
|
||||
|
||||
require.Equal(t, "tiered_expr", other["billing_mode"])
|
||||
require.Equal(t, "base", other["matched_tier"])
|
||||
require.NotEmpty(t, other["expr_b64"])
|
||||
}
|
||||
@@ -32,6 +32,26 @@ const (
|
||||
channelUpstreamModelUpdateNotifyMaxFailedChannelIDs = 10
|
||||
)
|
||||
|
||||
var channelUpstreamModelUpdateSelectFields = []string{
|
||||
"id",
|
||||
"name",
|
||||
"type",
|
||||
"key",
|
||||
"status",
|
||||
"base_url",
|
||||
"models",
|
||||
"model_mapping",
|
||||
"settings",
|
||||
"setting",
|
||||
"other",
|
||||
"group",
|
||||
"priority",
|
||||
"weight",
|
||||
"tag",
|
||||
"channel_info",
|
||||
"header_override",
|
||||
}
|
||||
|
||||
var (
|
||||
channelUpstreamModelUpdateTaskOnce sync.Once
|
||||
channelUpstreamModelUpdateTaskRunning atomic.Bool
|
||||
@@ -521,7 +541,7 @@ func runChannelUpstreamModelUpdateTaskOnce() {
|
||||
for {
|
||||
var channels []*model.Channel
|
||||
query := model.DB.
|
||||
Select("id", "name", "type", "key", "status", "base_url", "models", "settings", "setting", "other", "group", "priority", "weight", "tag", "channel_info", "header_override").
|
||||
Select(channelUpstreamModelUpdateSelectFields).
|
||||
Where("status = ?", common.ChannelStatusEnabled).
|
||||
Order("id asc").
|
||||
Limit(channelUpstreamModelUpdateTaskBatchSize)
|
||||
@@ -814,7 +834,7 @@ func collectPendingApplyUpstreamModelChanges(settings dto.ChannelOtherSettings)
|
||||
func findEnabledChannelsAfterID(lastID int, batchSize int) ([]*model.Channel, error) {
|
||||
var channels []*model.Channel
|
||||
query := model.DB.
|
||||
Select("id", "name", "type", "key", "status", "base_url", "models", "settings", "setting", "other", "group", "priority", "weight", "tag", "channel_info", "header_override").
|
||||
Select(channelUpstreamModelUpdateSelectFields).
|
||||
Where("status = ?", common.ChannelStatusEnabled).
|
||||
Order("id asc").
|
||||
Limit(batchSize)
|
||||
|
||||
@@ -81,6 +81,10 @@ func TestCollectPendingApplyUpstreamModelChanges(t *testing.T) {
|
||||
require.Equal(t, []string{"old-model"}, pendingRemoveModels)
|
||||
}
|
||||
|
||||
func TestChannelUpstreamModelUpdateSelectFieldsIncludeModelMapping(t *testing.T) {
|
||||
require.Contains(t, channelUpstreamModelUpdateSelectFields, "model_mapping")
|
||||
}
|
||||
|
||||
func TestNormalizeChannelModelMapping(t *testing.T) {
|
||||
modelMapping := `{
|
||||
" alias-model ": " upstream-model ",
|
||||
|
||||
@@ -0,0 +1,223 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type DiscordResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
type DiscordUser struct {
|
||||
UID string `json:"id"`
|
||||
ID string `json:"username"`
|
||||
Name string `json:"global_name"`
|
||||
}
|
||||
|
||||
func getDiscordUserInfoByCode(code string) (*DiscordUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("无效的参数")
|
||||
}
|
||||
|
||||
values := url.Values{}
|
||||
values.Set("client_id", system_setting.GetDiscordSettings().ClientId)
|
||||
values.Set("client_secret", system_setting.GetDiscordSettings().ClientSecret)
|
||||
values.Set("code", code)
|
||||
values.Set("grant_type", "authorization_code")
|
||||
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/discord", system_setting.ServerAddress))
|
||||
formData := values.Encode()
|
||||
req, err := http.NewRequest("POST", "https://discord.com/api/v10/oauth2/token", strings.NewReader(formData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var discordResponse DiscordResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&discordResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if discordResponse.AccessToken == "" {
|
||||
common.SysError("Discord 获取 Token 失败,请检查设置!")
|
||||
return nil, errors.New("Discord 获取 Token 失败,请检查设置!")
|
||||
}
|
||||
|
||||
req, err = http.NewRequest("GET", "https://discord.com/api/v10/users/@me", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+discordResponse.AccessToken)
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
if res2.StatusCode != http.StatusOK {
|
||||
common.SysError("Discord 获取用户信息失败!请检查设置!")
|
||||
return nil, errors.New("Discord 获取用户信息失败!请检查设置!")
|
||||
}
|
||||
|
||||
var discordUser DiscordUser
|
||||
err = json.NewDecoder(res2.Body).Decode(&discordUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if discordUser.UID == "" || discordUser.ID == "" {
|
||||
common.SysError("Discord 获取用户信息为空!请检查设置!")
|
||||
return nil, errors.New("Discord 获取用户信息为空!请检查设置!")
|
||||
}
|
||||
return &discordUser, nil
|
||||
}
|
||||
|
||||
func DiscordOAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
DiscordBind(c)
|
||||
return
|
||||
}
|
||||
if !system_setting.GetDiscordSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 Discord 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
discordUser, err := getDiscordUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
DiscordId: discordUser.UID,
|
||||
}
|
||||
if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
|
||||
err := user.FillUserByDiscordId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
if discordUser.ID != "" {
|
||||
user.Username = discordUser.ID
|
||||
} else {
|
||||
user.Username = "discord_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
}
|
||||
if discordUser.Name != "" {
|
||||
user.DisplayName = discordUser.Name
|
||||
} else {
|
||||
user.DisplayName = "Discord User"
|
||||
}
|
||||
err := user.Insert(0)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
func DiscordBind(c *gin.Context) {
|
||||
if !system_setting.GetDiscordSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 Discord 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
discordUser, err := getDiscordUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
DiscordId: discordUser.UID,
|
||||
}
|
||||
if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 Discord 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.DiscordId = discordUser.UID
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,220 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type GitHubOAuthResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Scope string `json:"scope"`
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
type GitHubUser struct {
|
||||
Login string `json:"login"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("无效的参数")
|
||||
}
|
||||
values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code}
|
||||
jsonData, err := json.Marshal(values)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequest("POST", "https://github.com/login/oauth/access_token", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 20 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var oAuthResponse GitHubOAuthResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err = http.NewRequest("GET", "https://api.github.com/user", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
var githubUser GitHubUser
|
||||
err = json.NewDecoder(res2.Body).Decode(&githubUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if githubUser.Login == "" {
|
||||
return nil, errors.New("返回值非法,用户字段为空,请稍后重试!")
|
||||
}
|
||||
return &githubUser, nil
|
||||
}
|
||||
|
||||
func GitHubOAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
GitHubBind(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !common.GitHubOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 GitHub 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
githubUser, err := getGitHubUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
GitHubId: githubUser.Login,
|
||||
}
|
||||
// IsGitHubIdAlreadyTaken is unscoped
|
||||
if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
|
||||
// FillUserByGitHubId is scoped
|
||||
err := user.FillUserByGitHubId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
// if user.Id == 0 , user has been deleted
|
||||
if user.Id == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户已注销",
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
if githubUser.Name != "" {
|
||||
user.DisplayName = githubUser.Name
|
||||
} else {
|
||||
user.DisplayName = "GitHub User"
|
||||
}
|
||||
user.Email = githubUser.Email
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Status = common.UserStatusEnabled
|
||||
affCode := session.Get("aff")
|
||||
inviterId := 0
|
||||
if affCode != nil {
|
||||
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
|
||||
}
|
||||
|
||||
if err := user.Insert(inviterId); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
func GitHubBind(c *gin.Context) {
|
||||
if !common.GitHubOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 GitHub 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
githubUser, err := getGitHubUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
GitHubId: githubUser.Login,
|
||||
}
|
||||
if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 GitHub 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
// id := c.GetInt("id") // critical bug!
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.GitHubId = githubUser.Login
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -0,0 +1,268 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type LinuxdoUser struct {
|
||||
Id int `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Name string `json:"name"`
|
||||
Active bool `json:"active"`
|
||||
TrustLevel int `json:"trust_level"`
|
||||
Silenced bool `json:"silenced"`
|
||||
}
|
||||
|
||||
func LinuxDoBind(c *gin.Context) {
|
||||
if !common.LinuxDOOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 Linux DO 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
code := c.Query("code")
|
||||
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
user := model.User{
|
||||
LinuxDOId: strconv.Itoa(linuxdoUser.Id),
|
||||
}
|
||||
|
||||
if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 Linux DO 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
user.Id = id.(int)
|
||||
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
user.LinuxDOId = strconv.Itoa(linuxdoUser.Id)
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
}
|
||||
|
||||
func getLinuxdoUserInfoByCode(code string, c *gin.Context) (*LinuxdoUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("invalid code")
|
||||
}
|
||||
|
||||
// Get access token using Basic auth
|
||||
tokenEndpoint := common.GetEnvOrDefaultString("LINUX_DO_TOKEN_ENDPOINT", "https://connect.linux.do/oauth2/token")
|
||||
credentials := common.LinuxDOClientId + ":" + common.LinuxDOClientSecret
|
||||
basicAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))
|
||||
|
||||
// Get redirect URI from request
|
||||
scheme := "http"
|
||||
if c.Request.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
redirectURI := fmt.Sprintf("%s://%s/api/oauth/linuxdo", scheme, c.Request.Host)
|
||||
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "authorization_code")
|
||||
data.Set("code", code)
|
||||
data.Set("redirect_uri", redirectURI)
|
||||
|
||||
req, err := http.NewRequest("POST", tokenEndpoint, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", basicAuth)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
client := http.Client{Timeout: 5 * time.Second}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to connect to Linux DO server")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
var tokenRes struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := json.NewDecoder(res.Body).Decode(&tokenRes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if tokenRes.AccessToken == "" {
|
||||
return nil, fmt.Errorf("failed to get access token: %s", tokenRes.Message)
|
||||
}
|
||||
|
||||
// Get user info
|
||||
userEndpoint := common.GetEnvOrDefaultString("LINUX_DO_USER_ENDPOINT", "https://connect.linux.do/api/user")
|
||||
req, err = http.NewRequest("GET", userEndpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+tokenRes.AccessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to get user info from Linux DO")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
|
||||
var linuxdoUser LinuxdoUser
|
||||
if err := json.NewDecoder(res2.Body).Decode(&linuxdoUser); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if linuxdoUser.Id == 0 {
|
||||
return nil, errors.New("invalid user info returned")
|
||||
}
|
||||
|
||||
return &linuxdoUser, nil
|
||||
}
|
||||
|
||||
func LinuxdoOAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
|
||||
errorCode := c.Query("error")
|
||||
if errorCode != "" {
|
||||
errorDescription := c.Query("error_description")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": errorDescription,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
LinuxDoBind(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !common.LinuxDOOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 Linux DO 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
code := c.Query("code")
|
||||
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
user := model.User{
|
||||
LinuxDOId: strconv.Itoa(linuxdoUser.Id),
|
||||
}
|
||||
|
||||
// Check if user exists
|
||||
if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
|
||||
err := user.FillUserByLinuxDOId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if user.Id == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户已注销",
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
if linuxdoUser.TrustLevel >= common.LinuxDOMinimumTrustLevel {
|
||||
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
user.DisplayName = linuxdoUser.Name
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Status = common.UserStatusEnabled
|
||||
|
||||
affCode := session.Get("aff")
|
||||
inviterId := 0
|
||||
if affCode != nil {
|
||||
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
|
||||
}
|
||||
|
||||
if err := user.Insert(inviterId); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "Linux DO 信任等级未达到管理员设置的最低信任等级",
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
@@ -61,6 +61,7 @@ func GetStatus(c *gin.Context) {
|
||||
"linuxdo_minimum_trust_level": common.LinuxDOMinimumTrustLevel,
|
||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||
"telegram_bot_name": common.TelegramBotName,
|
||||
"theme": system_setting.GetThemeSettings().Frontend,
|
||||
"system_name": common.SystemName,
|
||||
"logo": common.Logo,
|
||||
"footer_html": common.Footer,
|
||||
|
||||
+3
-5
@@ -15,9 +15,9 @@ import (
|
||||
"github.com/QuantumNous/new-api/relay/channel/minimax"
|
||||
"github.com/QuantumNous/new-api/relay/channel/moonshot"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
@@ -134,8 +134,7 @@ func ListModels(c *gin.Context, modelType int) {
|
||||
}
|
||||
for allowModel, _ := range tokenModelLimit {
|
||||
if !acceptUnsetRatioModel {
|
||||
_, _, exist := ratio_setting.GetModelRatioOrPrice(allowModel)
|
||||
if !exist {
|
||||
if !helper.HasModelBillingConfig(allowModel) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -182,8 +181,7 @@ func ListModels(c *gin.Context, modelType int) {
|
||||
}
|
||||
for _, modelName := range models {
|
||||
if !acceptUnsetRatioModel {
|
||||
_, _, exist := ratio_setting.GetModelRatioOrPrice(modelName)
|
||||
if !exist {
|
||||
if !helper.HasModelBillingConfig(modelName) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,242 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/config"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type listModelsResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Data []dto.OpenAIModels `json:"data"`
|
||||
Object string `json:"object"`
|
||||
}
|
||||
|
||||
func setupModelListControllerTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
initModelListColumnNames(t)
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
common.UsingSQLite = true
|
||||
common.UsingMySQL = false
|
||||
common.UsingPostgreSQL = false
|
||||
common.RedisEnabled = false
|
||||
|
||||
dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_"))
|
||||
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
model.DB = db
|
||||
model.LOG_DB = db
|
||||
|
||||
require.NoError(t, db.AutoMigrate(&model.User{}, &model.Channel{}, &model.Ability{}, &model.Model{}, &model.Vendor{}))
|
||||
|
||||
t.Cleanup(func() {
|
||||
sqlDB, err := db.DB()
|
||||
if err == nil {
|
||||
_ = sqlDB.Close()
|
||||
}
|
||||
})
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func initModelListColumnNames(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
originalIsMasterNode := common.IsMasterNode
|
||||
originalSQLitePath := common.SQLitePath
|
||||
originalUsingSQLite := common.UsingSQLite
|
||||
originalUsingMySQL := common.UsingMySQL
|
||||
originalUsingPostgreSQL := common.UsingPostgreSQL
|
||||
originalSQLDSN, hadSQLDSN := os.LookupEnv("SQL_DSN")
|
||||
defer func() {
|
||||
common.IsMasterNode = originalIsMasterNode
|
||||
common.SQLitePath = originalSQLitePath
|
||||
common.UsingSQLite = originalUsingSQLite
|
||||
common.UsingMySQL = originalUsingMySQL
|
||||
common.UsingPostgreSQL = originalUsingPostgreSQL
|
||||
if hadSQLDSN {
|
||||
require.NoError(t, os.Setenv("SQL_DSN", originalSQLDSN))
|
||||
} else {
|
||||
require.NoError(t, os.Unsetenv("SQL_DSN"))
|
||||
}
|
||||
}()
|
||||
|
||||
common.IsMasterNode = false
|
||||
common.SQLitePath = fmt.Sprintf("file:%s_init?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_"))
|
||||
common.UsingSQLite = false
|
||||
common.UsingMySQL = false
|
||||
common.UsingPostgreSQL = false
|
||||
require.NoError(t, os.Setenv("SQL_DSN", "local"))
|
||||
|
||||
require.NoError(t, model.InitDB())
|
||||
if model.DB != nil {
|
||||
sqlDB, err := model.DB.DB()
|
||||
if err == nil {
|
||||
_ = sqlDB.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func withTieredBillingConfig(t *testing.T, modes map[string]string, exprs map[string]string) {
|
||||
t.Helper()
|
||||
|
||||
saved := map[string]string{}
|
||||
require.NoError(t, config.GlobalConfig.SaveToDB(func(key, value string) error {
|
||||
if strings.HasPrefix(key, "billing_setting.") {
|
||||
saved[key] = value
|
||||
}
|
||||
return nil
|
||||
}))
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, config.GlobalConfig.LoadFromDB(saved))
|
||||
model.InvalidatePricingCache()
|
||||
})
|
||||
|
||||
modeBytes, err := common.Marshal(modes)
|
||||
require.NoError(t, err)
|
||||
exprBytes, err := common.Marshal(exprs)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, config.GlobalConfig.LoadFromDB(map[string]string{
|
||||
"billing_setting.billing_mode": string(modeBytes),
|
||||
"billing_setting.billing_expr": string(exprBytes),
|
||||
}))
|
||||
model.InvalidatePricingCache()
|
||||
}
|
||||
|
||||
func withSelfUseModeDisabled(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
original := operation_setting.SelfUseModeEnabled
|
||||
operation_setting.SelfUseModeEnabled = false
|
||||
t.Cleanup(func() {
|
||||
operation_setting.SelfUseModeEnabled = original
|
||||
})
|
||||
}
|
||||
|
||||
func decodeListModelsResponse(t *testing.T, recorder *httptest.ResponseRecorder) map[string]struct{} {
|
||||
t.Helper()
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
var payload listModelsResponse
|
||||
require.NoError(t, common.Unmarshal(recorder.Body.Bytes(), &payload))
|
||||
require.True(t, payload.Success)
|
||||
require.Equal(t, "list", payload.Object)
|
||||
|
||||
ids := make(map[string]struct{}, len(payload.Data))
|
||||
for _, item := range payload.Data {
|
||||
ids[item.Id] = struct{}{}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func pricingByModelName(pricings []model.Pricing) map[string]model.Pricing {
|
||||
byName := make(map[string]model.Pricing, len(pricings))
|
||||
for _, pricing := range pricings {
|
||||
byName[pricing.ModelName] = pricing
|
||||
}
|
||||
return byName
|
||||
}
|
||||
|
||||
func TestListModelsIncludesTieredBillingModel(t *testing.T) {
|
||||
withSelfUseModeDisabled(t)
|
||||
withTieredBillingConfig(t, map[string]string{
|
||||
"zz-tiered-visible-model": "tiered_expr",
|
||||
"zz-tiered-empty-expr-model": "tiered_expr",
|
||||
"zz-tiered-missing-expr-model": "tiered_expr",
|
||||
}, map[string]string{
|
||||
"zz-tiered-visible-model": `tier("base", p * 1 + c * 2)`,
|
||||
"zz-tiered-empty-expr-model": " ",
|
||||
})
|
||||
|
||||
db := setupModelListControllerTestDB(t)
|
||||
require.NoError(t, db.Create(&model.User{
|
||||
Id: 1001,
|
||||
Username: "model-list-user",
|
||||
Password: "password",
|
||||
Group: "default",
|
||||
Status: common.UserStatusEnabled,
|
||||
}).Error)
|
||||
require.NoError(t, db.Create(&[]model.Ability{
|
||||
{Group: "default", Model: "zz-tiered-visible-model", ChannelId: 1, Enabled: true},
|
||||
{Group: "default", Model: "zz-tiered-empty-expr-model", ChannelId: 1, Enabled: true},
|
||||
{Group: "default", Model: "zz-tiered-missing-expr-model", ChannelId: 1, Enabled: true},
|
||||
{Group: "default", Model: "zz-unpriced-model", ChannelId: 1, Enabled: true},
|
||||
}).Error)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
|
||||
ctx.Set("id", 1001)
|
||||
|
||||
ListModels(ctx, constant.ChannelTypeOpenAI)
|
||||
|
||||
ids := decodeListModelsResponse(t, recorder)
|
||||
require.Contains(t, ids, "zz-tiered-visible-model")
|
||||
require.NotContains(t, ids, "zz-tiered-empty-expr-model")
|
||||
require.NotContains(t, ids, "zz-tiered-missing-expr-model")
|
||||
require.NotContains(t, ids, "zz-unpriced-model")
|
||||
|
||||
pricingByName := pricingByModelName(model.GetPricing())
|
||||
visiblePricing, ok := pricingByName["zz-tiered-visible-model"]
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "tiered_expr", visiblePricing.BillingMode)
|
||||
require.NotEmpty(t, visiblePricing.BillingExpr)
|
||||
|
||||
emptyExprPricing, ok := pricingByName["zz-tiered-empty-expr-model"]
|
||||
require.True(t, ok)
|
||||
require.Empty(t, emptyExprPricing.BillingMode)
|
||||
require.Empty(t, emptyExprPricing.BillingExpr)
|
||||
|
||||
missingExprPricing, ok := pricingByName["zz-tiered-missing-expr-model"]
|
||||
require.True(t, ok)
|
||||
require.Empty(t, missingExprPricing.BillingMode)
|
||||
require.Empty(t, missingExprPricing.BillingExpr)
|
||||
}
|
||||
|
||||
func TestListModelsTokenLimitIncludesTieredBillingModel(t *testing.T) {
|
||||
withSelfUseModeDisabled(t)
|
||||
withTieredBillingConfig(t, map[string]string{
|
||||
"zz-token-tiered-visible-model": "tiered_expr",
|
||||
"zz-token-tiered-empty-expr-model": "tiered_expr",
|
||||
"zz-token-tiered-missing-expr-model": "tiered_expr",
|
||||
}, map[string]string{
|
||||
"zz-token-tiered-visible-model": `tier("base", p * 1 + c * 2)`,
|
||||
"zz-token-tiered-empty-expr-model": "",
|
||||
})
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
|
||||
common.SetContextKey(ctx, constant.ContextKeyTokenModelLimitEnabled, true)
|
||||
common.SetContextKey(ctx, constant.ContextKeyTokenModelLimit, map[string]bool{
|
||||
"zz-token-tiered-visible-model": true,
|
||||
"zz-token-tiered-empty-expr-model": true,
|
||||
"zz-token-tiered-missing-expr-model": true,
|
||||
"zz-token-unpriced-model": true,
|
||||
})
|
||||
|
||||
ListModels(ctx, constant.ChannelTypeOpenAI)
|
||||
|
||||
ids := decodeListModelsResponse(t, recorder)
|
||||
require.Contains(t, ids, "zz-token-tiered-visible-model")
|
||||
require.NotContains(t, ids, "zz-token-tiered-empty-expr-model")
|
||||
require.NotContains(t, ids, "zz-token-tiered-missing-expr-model")
|
||||
require.NotContains(t, ids, "zz-token-unpriced-model")
|
||||
}
|
||||
@@ -0,0 +1,228 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type OidcResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
type OidcUser struct {
|
||||
OpenID string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
Picture string `json:"picture"`
|
||||
}
|
||||
|
||||
func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("无效的参数")
|
||||
}
|
||||
|
||||
values := url.Values{}
|
||||
values.Set("client_id", system_setting.GetOIDCSettings().ClientId)
|
||||
values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret)
|
||||
values.Set("code", code)
|
||||
values.Set("grant_type", "authorization_code")
|
||||
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", system_setting.ServerAddress))
|
||||
formData := values.Encode()
|
||||
req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var oidcResponse OidcResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&oidcResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if oidcResponse.AccessToken == "" {
|
||||
common.SysLog("OIDC 获取 Token 失败,请检查设置!")
|
||||
return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
|
||||
}
|
||||
|
||||
req, err = http.NewRequest("GET", system_setting.GetOIDCSettings().UserInfoEndpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
if res2.StatusCode != http.StatusOK {
|
||||
common.SysLog("OIDC 获取用户信息失败!请检查设置!")
|
||||
return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
|
||||
}
|
||||
|
||||
var oidcUser OidcUser
|
||||
err = json.NewDecoder(res2.Body).Decode(&oidcUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if oidcUser.OpenID == "" || oidcUser.Email == "" {
|
||||
common.SysLog("OIDC 获取用户信息为空!请检查设置!")
|
||||
return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
|
||||
}
|
||||
return &oidcUser, nil
|
||||
}
|
||||
|
||||
func OidcAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
OidcBind(c)
|
||||
return
|
||||
}
|
||||
if !system_setting.GetOIDCSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 OIDC 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
oidcUser, err := getOidcUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
OidcId: oidcUser.OpenID,
|
||||
}
|
||||
if model.IsOidcIdAlreadyTaken(user.OidcId) {
|
||||
err := user.FillUserByOidcId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
user.Email = oidcUser.Email
|
||||
if oidcUser.PreferredUsername != "" {
|
||||
user.Username = oidcUser.PreferredUsername
|
||||
} else {
|
||||
user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
}
|
||||
if oidcUser.Name != "" {
|
||||
user.DisplayName = oidcUser.Name
|
||||
} else {
|
||||
user.DisplayName = "OIDC User"
|
||||
}
|
||||
err := user.Insert(0)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
func OidcBind(c *gin.Context) {
|
||||
if !system_setting.GetOIDCSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 OIDC 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
oidcUser, err := getOidcUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
OidcId: oidcUser.OpenID,
|
||||
}
|
||||
if model.IsOidcIdAlreadyTaken(user.OidcId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 OIDC 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
// id := c.GetInt("id") // critical bug!
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.OidcId = oidcUser.OpenID
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
return
|
||||
}
|
||||
+20
-2
@@ -27,6 +27,15 @@ var completionRatioMetaOptionKeys = []string{
|
||||
"AudioCompletionRatio",
|
||||
}
|
||||
|
||||
func isVisiblePublicKeyOption(key string) bool {
|
||||
switch key {
|
||||
case "WaffoPancakeWebhookPublicKey", "WaffoPancakeWebhookTestKey":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func collectModelNamesFromOptionValue(raw string, modelNames map[string]struct{}) {
|
||||
if strings.TrimSpace(raw) == "" {
|
||||
return
|
||||
@@ -66,11 +75,12 @@ func GetOptions(c *gin.Context) {
|
||||
common.OptionMapRWMutex.Lock()
|
||||
for k, v := range common.OptionMap {
|
||||
value := common.Interface2String(v)
|
||||
if strings.HasSuffix(k, "Token") ||
|
||||
isSensitiveKey := strings.HasSuffix(k, "Token") ||
|
||||
strings.HasSuffix(k, "Secret") ||
|
||||
strings.HasSuffix(k, "Key") ||
|
||||
strings.HasSuffix(k, "secret") ||
|
||||
strings.HasSuffix(k, "api_key") {
|
||||
strings.HasSuffix(k, "api_key")
|
||||
if isSensitiveKey && !isVisiblePublicKeyOption(k) {
|
||||
continue
|
||||
}
|
||||
options = append(options, &model.Option{
|
||||
@@ -188,6 +198,14 @@ func UpdateOption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
case "theme.frontend":
|
||||
if option.Value != "default" && option.Value != "classic" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的主题值,可选值:default(新版前端)、classic(经典前端)",
|
||||
})
|
||||
return
|
||||
}
|
||||
case "GroupRatio":
|
||||
err = ratio_setting.CheckGroupRatio(option.Value.(string))
|
||||
if err != nil {
|
||||
|
||||
@@ -36,6 +36,10 @@ func PasskeyRegisterBegin(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if !requirePasskeyRegistrationVerification(c, user.Id) {
|
||||
return
|
||||
}
|
||||
|
||||
credential, err := model.GetPasskeyByUserID(user.Id)
|
||||
if err != nil && !errors.Is(err, model.ErrPasskeyNotFound) {
|
||||
common.ApiError(c, err)
|
||||
@@ -96,6 +100,10 @@ func PasskeyRegisterFinish(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if !requirePasskeyRegistrationVerification(c, user.Id) {
|
||||
return
|
||||
}
|
||||
|
||||
wa, err := passkeysvc.BuildWebAuthn(c.Request)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
@@ -151,6 +159,10 @@ func PasskeyDelete(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if !requirePasskeyDeleteVerification(c, user.Id) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := model.DeletePasskeyByUserID(user.Id); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
@@ -474,6 +486,7 @@ func PasskeyVerifyFinish(c *gin.Context) {
|
||||
// Mark passkey as ready; /api/verify will convert this into the final secure verification session.
|
||||
session.Set(PasskeyReadySessionKey, time.Now().Unix())
|
||||
session.Delete(SecureVerificationSessionKey)
|
||||
session.Delete(secureVerificationMethodSessionKey)
|
||||
if err := session.Save(); err != nil {
|
||||
common.ApiError(c, fmt.Errorf("保存验证状态失败: %v", err))
|
||||
return
|
||||
@@ -504,3 +517,60 @@ func getSessionUser(c *gin.Context) (*model.User, error) {
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func requirePasskeyRegistrationVerification(c *gin.Context, userID int) bool {
|
||||
twoFA, err := model.GetTwoFAByUserId(userID)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return false
|
||||
}
|
||||
if twoFA == nil || !twoFA.IsEnabled {
|
||||
return true
|
||||
}
|
||||
return requireSecureVerificationMethod(c, secureVerificationMethod2FA)
|
||||
}
|
||||
|
||||
func requirePasskeyDeleteVerification(c *gin.Context, userID int) bool {
|
||||
twoFA, err := model.GetTwoFAByUserId(userID)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return false
|
||||
}
|
||||
if twoFA != nil && twoFA.IsEnabled {
|
||||
return requireSecureVerificationMethod(c, secureVerificationMethod2FA)
|
||||
}
|
||||
|
||||
_, err = model.GetPasskeyByUserID(userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, model.ErrPasskeyNotFound) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该用户尚未绑定 Passkey",
|
||||
})
|
||||
return false
|
||||
}
|
||||
common.ApiError(c, err)
|
||||
return false
|
||||
}
|
||||
|
||||
return requireSecureVerificationMethod(c, secureVerificationMethodPasskey)
|
||||
}
|
||||
|
||||
func requireSecureVerificationMethod(c *gin.Context, method string) bool {
|
||||
session := sessions.Default(c)
|
||||
verifiedAt, ok := session.Get(SecureVerificationSessionKey).(int64)
|
||||
if !ok || time.Now().Unix()-verifiedAt >= SecureVerificationTimeout {
|
||||
session.Delete(SecureVerificationSessionKey)
|
||||
session.Delete(secureVerificationMethodSessionKey)
|
||||
_ = session.Save()
|
||||
common.ApiErrorMsg(c, "请先完成安全验证")
|
||||
return false
|
||||
}
|
||||
|
||||
if verifiedMethod, ok := session.Get(secureVerificationMethodSessionKey).(string); !ok || verifiedMethod != method {
|
||||
common.ApiErrorMsg(c, "请先完成对应的安全验证")
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/setting"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
)
|
||||
|
||||
func isStripeTopUpEnabled() bool {
|
||||
return strings.TrimSpace(setting.StripeApiSecret) != "" &&
|
||||
strings.TrimSpace(setting.StripeWebhookSecret) != "" &&
|
||||
strings.TrimSpace(setting.StripePriceId) != ""
|
||||
}
|
||||
|
||||
func isStripeWebhookConfigured() bool {
|
||||
return strings.TrimSpace(setting.StripeWebhookSecret) != ""
|
||||
}
|
||||
|
||||
func isStripeWebhookEnabled() bool {
|
||||
return isStripeTopUpEnabled()
|
||||
}
|
||||
|
||||
func isCreemTopUpEnabled() bool {
|
||||
products := strings.TrimSpace(setting.CreemProducts)
|
||||
return strings.TrimSpace(setting.CreemApiKey) != "" &&
|
||||
products != "" &&
|
||||
products != "[]"
|
||||
}
|
||||
|
||||
func isCreemWebhookConfigured() bool {
|
||||
return strings.TrimSpace(setting.CreemWebhookSecret) != ""
|
||||
}
|
||||
|
||||
func isCreemWebhookEnabled() bool {
|
||||
return isCreemTopUpEnabled() && isCreemWebhookConfigured()
|
||||
}
|
||||
|
||||
func isWaffoTopUpEnabled() bool {
|
||||
if !setting.WaffoEnabled {
|
||||
return false
|
||||
}
|
||||
|
||||
return isWaffoWebhookConfigured()
|
||||
}
|
||||
|
||||
func isWaffoWebhookConfigured() bool {
|
||||
if setting.WaffoSandbox {
|
||||
return strings.TrimSpace(setting.WaffoSandboxApiKey) != "" &&
|
||||
strings.TrimSpace(setting.WaffoSandboxPrivateKey) != "" &&
|
||||
strings.TrimSpace(setting.WaffoSandboxPublicCert) != ""
|
||||
}
|
||||
|
||||
return strings.TrimSpace(setting.WaffoApiKey) != "" &&
|
||||
strings.TrimSpace(setting.WaffoPrivateKey) != "" &&
|
||||
strings.TrimSpace(setting.WaffoPublicCert) != ""
|
||||
}
|
||||
|
||||
func isWaffoWebhookEnabled() bool {
|
||||
return isWaffoTopUpEnabled()
|
||||
}
|
||||
|
||||
func isWaffoPancakeTopUpEnabled() bool {
|
||||
if !setting.WaffoPancakeEnabled {
|
||||
return false
|
||||
}
|
||||
|
||||
return isWaffoPancakeWebhookConfigured() &&
|
||||
strings.TrimSpace(setting.WaffoPancakeMerchantID) != "" &&
|
||||
strings.TrimSpace(setting.WaffoPancakePrivateKey) != "" &&
|
||||
strings.TrimSpace(setting.WaffoPancakeStoreID) != "" &&
|
||||
strings.TrimSpace(setting.WaffoPancakeProductID) != ""
|
||||
}
|
||||
|
||||
func isWaffoPancakeWebhookConfigured() bool {
|
||||
currentWebhookKey := strings.TrimSpace(setting.WaffoPancakeWebhookPublicKey)
|
||||
if setting.WaffoPancakeSandbox {
|
||||
currentWebhookKey = strings.TrimSpace(setting.WaffoPancakeWebhookTestKey)
|
||||
}
|
||||
|
||||
return currentWebhookKey != ""
|
||||
}
|
||||
|
||||
func isWaffoPancakeWebhookEnabled() bool {
|
||||
return isWaffoPancakeTopUpEnabled()
|
||||
}
|
||||
|
||||
func isEpayTopUpEnabled() bool {
|
||||
return isEpayWebhookConfigured() && len(operation_setting.PayMethods) > 0
|
||||
}
|
||||
|
||||
func isEpayWebhookConfigured() bool {
|
||||
return strings.TrimSpace(operation_setting.PayAddress) != "" &&
|
||||
strings.TrimSpace(operation_setting.EpayId) != "" &&
|
||||
strings.TrimSpace(operation_setting.EpayKey) != ""
|
||||
}
|
||||
|
||||
func isEpayWebhookEnabled() bool {
|
||||
return isEpayTopUpEnabled()
|
||||
}
|
||||
@@ -0,0 +1,166 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/setting"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStripeWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
|
||||
originalAPISecret := setting.StripeApiSecret
|
||||
originalWebhookSecret := setting.StripeWebhookSecret
|
||||
originalPriceID := setting.StripePriceId
|
||||
t.Cleanup(func() {
|
||||
setting.StripeApiSecret = originalAPISecret
|
||||
setting.StripeWebhookSecret = originalWebhookSecret
|
||||
setting.StripePriceId = originalPriceID
|
||||
})
|
||||
|
||||
setting.StripeWebhookSecret = ""
|
||||
setting.StripeApiSecret = "sk_test_123"
|
||||
setting.StripePriceId = "price_123"
|
||||
require.False(t, isStripeWebhookEnabled())
|
||||
|
||||
setting.StripeWebhookSecret = "whsec_test"
|
||||
require.True(t, isStripeWebhookEnabled())
|
||||
|
||||
setting.StripePriceId = ""
|
||||
require.False(t, isStripeWebhookEnabled())
|
||||
}
|
||||
|
||||
func TestCreemWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
|
||||
originalAPIKey := setting.CreemApiKey
|
||||
originalProducts := setting.CreemProducts
|
||||
originalWebhookSecret := setting.CreemWebhookSecret
|
||||
t.Cleanup(func() {
|
||||
setting.CreemApiKey = originalAPIKey
|
||||
setting.CreemProducts = originalProducts
|
||||
setting.CreemWebhookSecret = originalWebhookSecret
|
||||
})
|
||||
|
||||
setting.CreemWebhookSecret = ""
|
||||
setting.CreemApiKey = "creem_api_key"
|
||||
setting.CreemProducts = `[{"productId":"prod_123"}]`
|
||||
require.False(t, isCreemWebhookEnabled())
|
||||
|
||||
setting.CreemWebhookSecret = "creem_secret"
|
||||
require.True(t, isCreemWebhookEnabled())
|
||||
|
||||
setting.CreemProducts = "[]"
|
||||
require.False(t, isCreemWebhookEnabled())
|
||||
}
|
||||
|
||||
func TestWaffoWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
|
||||
originalEnabled := setting.WaffoEnabled
|
||||
originalSandbox := setting.WaffoSandbox
|
||||
originalAPIKey := setting.WaffoApiKey
|
||||
originalPrivateKey := setting.WaffoPrivateKey
|
||||
originalPublicCert := setting.WaffoPublicCert
|
||||
originalSandboxAPIKey := setting.WaffoSandboxApiKey
|
||||
originalSandboxPrivateKey := setting.WaffoSandboxPrivateKey
|
||||
originalSandboxPublicCert := setting.WaffoSandboxPublicCert
|
||||
t.Cleanup(func() {
|
||||
setting.WaffoEnabled = originalEnabled
|
||||
setting.WaffoSandbox = originalSandbox
|
||||
setting.WaffoApiKey = originalAPIKey
|
||||
setting.WaffoPrivateKey = originalPrivateKey
|
||||
setting.WaffoPublicCert = originalPublicCert
|
||||
setting.WaffoSandboxApiKey = originalSandboxAPIKey
|
||||
setting.WaffoSandboxPrivateKey = originalSandboxPrivateKey
|
||||
setting.WaffoSandboxPublicCert = originalSandboxPublicCert
|
||||
})
|
||||
|
||||
setting.WaffoEnabled = true
|
||||
setting.WaffoSandbox = false
|
||||
setting.WaffoApiKey = ""
|
||||
setting.WaffoPrivateKey = "private"
|
||||
setting.WaffoPublicCert = "public"
|
||||
require.False(t, isWaffoWebhookEnabled())
|
||||
|
||||
setting.WaffoApiKey = "api"
|
||||
require.True(t, isWaffoWebhookEnabled())
|
||||
|
||||
setting.WaffoEnabled = false
|
||||
require.False(t, isWaffoWebhookEnabled())
|
||||
|
||||
setting.WaffoEnabled = true
|
||||
setting.WaffoSandbox = true
|
||||
setting.WaffoSandboxApiKey = ""
|
||||
setting.WaffoSandboxPrivateKey = "sandbox_private"
|
||||
setting.WaffoSandboxPublicCert = "sandbox_public"
|
||||
require.False(t, isWaffoWebhookEnabled())
|
||||
|
||||
setting.WaffoSandboxApiKey = "sandbox_api"
|
||||
require.True(t, isWaffoWebhookEnabled())
|
||||
}
|
||||
|
||||
func TestWaffoPancakeWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
|
||||
originalEnabled := setting.WaffoPancakeEnabled
|
||||
originalSandbox := setting.WaffoPancakeSandbox
|
||||
originalMerchantID := setting.WaffoPancakeMerchantID
|
||||
originalPrivateKey := setting.WaffoPancakePrivateKey
|
||||
originalWebhookPublicKey := setting.WaffoPancakeWebhookPublicKey
|
||||
originalWebhookTestKey := setting.WaffoPancakeWebhookTestKey
|
||||
originalStoreID := setting.WaffoPancakeStoreID
|
||||
originalProductID := setting.WaffoPancakeProductID
|
||||
t.Cleanup(func() {
|
||||
setting.WaffoPancakeEnabled = originalEnabled
|
||||
setting.WaffoPancakeSandbox = originalSandbox
|
||||
setting.WaffoPancakeMerchantID = originalMerchantID
|
||||
setting.WaffoPancakePrivateKey = originalPrivateKey
|
||||
setting.WaffoPancakeWebhookPublicKey = originalWebhookPublicKey
|
||||
setting.WaffoPancakeWebhookTestKey = originalWebhookTestKey
|
||||
setting.WaffoPancakeStoreID = originalStoreID
|
||||
setting.WaffoPancakeProductID = originalProductID
|
||||
})
|
||||
|
||||
setting.WaffoPancakeEnabled = true
|
||||
setting.WaffoPancakeSandbox = false
|
||||
setting.WaffoPancakeMerchantID = "merchant"
|
||||
setting.WaffoPancakePrivateKey = "private"
|
||||
setting.WaffoPancakeStoreID = "store"
|
||||
setting.WaffoPancakeProductID = "product"
|
||||
setting.WaffoPancakeWebhookPublicKey = ""
|
||||
require.False(t, isWaffoPancakeWebhookEnabled())
|
||||
|
||||
setting.WaffoPancakeWebhookPublicKey = "public"
|
||||
require.True(t, isWaffoPancakeWebhookEnabled())
|
||||
|
||||
setting.WaffoPancakeEnabled = false
|
||||
require.False(t, isWaffoPancakeWebhookEnabled())
|
||||
|
||||
setting.WaffoPancakeEnabled = true
|
||||
setting.WaffoPancakeSandbox = true
|
||||
setting.WaffoPancakeWebhookTestKey = ""
|
||||
require.False(t, isWaffoPancakeWebhookEnabled())
|
||||
|
||||
setting.WaffoPancakeWebhookTestKey = "test_public"
|
||||
require.True(t, isWaffoPancakeWebhookEnabled())
|
||||
}
|
||||
|
||||
func TestEpayWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
|
||||
originalPayAddress := operation_setting.PayAddress
|
||||
originalEpayID := operation_setting.EpayId
|
||||
originalEpayKey := operation_setting.EpayKey
|
||||
originalPayMethods := operation_setting.PayMethods
|
||||
t.Cleanup(func() {
|
||||
operation_setting.PayAddress = originalPayAddress
|
||||
operation_setting.EpayId = originalEpayID
|
||||
operation_setting.EpayKey = originalEpayKey
|
||||
operation_setting.PayMethods = originalPayMethods
|
||||
})
|
||||
|
||||
operation_setting.PayAddress = "https://pay.example.com"
|
||||
operation_setting.EpayId = "epay_id"
|
||||
operation_setting.EpayKey = ""
|
||||
operation_setting.PayMethods = []map[string]string{{"type": "alipay"}}
|
||||
require.False(t, isEpayWebhookEnabled())
|
||||
|
||||
operation_setting.EpayKey = "epay_key"
|
||||
require.True(t, isEpayWebhookEnabled())
|
||||
|
||||
operation_setting.PayMethods = nil
|
||||
require.False(t, isEpayWebhookEnabled())
|
||||
}
|
||||
+161
-46
@@ -21,14 +21,16 @@ import (
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/billing_setting"
|
||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTimeoutSeconds = 10
|
||||
defaultEndpoint = "/api/ratio_config"
|
||||
defaultEndpoint = "/api/pricing"
|
||||
maxConcurrentFetches = 8
|
||||
maxRatioConfigBytes = 10 << 20 // 10MB
|
||||
floatEpsilon = 1e-9
|
||||
@@ -59,7 +61,29 @@ func valuesEqual(a, b interface{}) bool {
|
||||
return a == b
|
||||
}
|
||||
|
||||
var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
|
||||
var pricingSyncFields = []string{
|
||||
"model_ratio",
|
||||
"completion_ratio",
|
||||
"cache_ratio",
|
||||
"create_cache_ratio",
|
||||
"image_ratio",
|
||||
"audio_ratio",
|
||||
"audio_completion_ratio",
|
||||
"model_price",
|
||||
billing_setting.BillingModeField,
|
||||
billing_setting.BillingExprField,
|
||||
}
|
||||
|
||||
var numericPricingSyncFields = map[string]bool{
|
||||
"model_ratio": true,
|
||||
"completion_ratio": true,
|
||||
"cache_ratio": true,
|
||||
"create_cache_ratio": true,
|
||||
"image_ratio": true,
|
||||
"audio_ratio": true,
|
||||
"audio_completion_ratio": true,
|
||||
"model_price": true,
|
||||
}
|
||||
|
||||
type upstreamResult struct {
|
||||
Name string `json:"name"`
|
||||
@@ -67,6 +91,54 @@ type upstreamResult struct {
|
||||
Err string `json:"err,omitempty"`
|
||||
}
|
||||
|
||||
func valueMap(value any) map[string]any {
|
||||
switch typed := value.(type) {
|
||||
case map[string]any:
|
||||
return typed
|
||||
case map[string]float64:
|
||||
return lo.MapValues(typed, func(value float64, _ string) any { return value })
|
||||
case map[string]string:
|
||||
return lo.MapValues(typed, func(value string, _ string) any { return value })
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func asFloat64(value any) (float64, bool) {
|
||||
switch typed := value.(type) {
|
||||
case float64:
|
||||
return typed, true
|
||||
case float32:
|
||||
return float64(typed), true
|
||||
case int:
|
||||
return float64(typed), true
|
||||
case int64:
|
||||
return float64(typed), true
|
||||
case json.Number:
|
||||
parsed, err := typed.Float64()
|
||||
return parsed, err == nil
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeSyncValue(field string, value any) any {
|
||||
if numericPricingSyncFields[field] {
|
||||
if parsed, ok := asFloat64(value); ok {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func getLocalPricingSyncData() map[string]any {
|
||||
data := billing_setting.GetPricingSyncData(map[string]any(ratio_setting.GetExposedData()))
|
||||
data["image_ratio"] = ratio_setting.GetImageRatioCopy()
|
||||
data["audio_ratio"] = ratio_setting.GetAudioRatioCopy()
|
||||
data["audio_completion_ratio"] = ratio_setting.GetAudioCompletionRatioCopy()
|
||||
return data
|
||||
}
|
||||
|
||||
func FetchUpstreamRatios(c *gin.Context) {
|
||||
var req dto.UpstreamRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
@@ -293,7 +365,7 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
if err := common.Unmarshal(body.Data, &type1Data); err == nil {
|
||||
// 如果包含至少一个 ratioTypes 字段,则认为是 type1
|
||||
isType1 := false
|
||||
for _, rt := range ratioTypes {
|
||||
for _, rt := range pricingSyncFields {
|
||||
if _, ok := type1Data[rt]; ok {
|
||||
isType1 = true
|
||||
break
|
||||
@@ -307,11 +379,18 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
|
||||
// 如果不是 type1,则尝试按 type2 (/api/pricing) 解析
|
||||
var pricingItems []struct {
|
||||
ModelName string `json:"model_name"`
|
||||
QuotaType int `json:"quota_type"`
|
||||
ModelRatio float64 `json:"model_ratio"`
|
||||
ModelPrice float64 `json:"model_price"`
|
||||
CompletionRatio float64 `json:"completion_ratio"`
|
||||
ModelName string `json:"model_name"`
|
||||
QuotaType int `json:"quota_type"`
|
||||
ModelRatio float64 `json:"model_ratio"`
|
||||
ModelPrice float64 `json:"model_price"`
|
||||
CompletionRatio float64 `json:"completion_ratio"`
|
||||
CacheRatio *float64 `json:"cache_ratio"`
|
||||
CreateCacheRatio *float64 `json:"create_cache_ratio"`
|
||||
ImageRatio *float64 `json:"image_ratio"`
|
||||
AudioRatio *float64 `json:"audio_ratio"`
|
||||
AudioCompletionRatio *float64 `json:"audio_completion_ratio"`
|
||||
BillingMode string `json:"billing_mode"`
|
||||
BillingExpr string `json:"billing_expr"`
|
||||
}
|
||||
if err := common.Unmarshal(body.Data, &pricingItems); err != nil {
|
||||
logger.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
|
||||
@@ -321,9 +400,23 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
|
||||
modelRatioMap := make(map[string]float64)
|
||||
completionRatioMap := make(map[string]float64)
|
||||
cacheRatioMap := make(map[string]float64)
|
||||
createCacheRatioMap := make(map[string]float64)
|
||||
imageRatioMap := make(map[string]float64)
|
||||
audioRatioMap := make(map[string]float64)
|
||||
audioCompletionRatioMap := make(map[string]float64)
|
||||
modelPriceMap := make(map[string]float64)
|
||||
billingModeMap := make(map[string]string)
|
||||
billingExprMap := make(map[string]string)
|
||||
|
||||
for _, item := range pricingItems {
|
||||
if item.ModelName == "" {
|
||||
continue
|
||||
}
|
||||
if item.BillingMode == billing_setting.BillingModeTieredExpr && strings.TrimSpace(item.BillingExpr) != "" {
|
||||
billingModeMap[item.ModelName] = billing_setting.BillingModeTieredExpr
|
||||
billingExprMap[item.ModelName] = item.BillingExpr
|
||||
}
|
||||
if item.QuotaType == 1 {
|
||||
modelPriceMap[item.ModelName] = item.ModelPrice
|
||||
} else {
|
||||
@@ -331,6 +424,21 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
// completionRatio 可能为 0,此时也直接赋值,保持与上游一致
|
||||
completionRatioMap[item.ModelName] = item.CompletionRatio
|
||||
}
|
||||
if item.CacheRatio != nil {
|
||||
cacheRatioMap[item.ModelName] = *item.CacheRatio
|
||||
}
|
||||
if item.CreateCacheRatio != nil {
|
||||
createCacheRatioMap[item.ModelName] = *item.CreateCacheRatio
|
||||
}
|
||||
if item.ImageRatio != nil {
|
||||
imageRatioMap[item.ModelName] = *item.ImageRatio
|
||||
}
|
||||
if item.AudioRatio != nil {
|
||||
audioRatioMap[item.ModelName] = *item.AudioRatio
|
||||
}
|
||||
if item.AudioCompletionRatio != nil {
|
||||
audioCompletionRatioMap[item.ModelName] = *item.AudioCompletionRatio
|
||||
}
|
||||
}
|
||||
|
||||
converted := make(map[string]any)
|
||||
@@ -350,6 +458,21 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
}
|
||||
converted["completion_ratio"] = compAny
|
||||
}
|
||||
if len(cacheRatioMap) > 0 {
|
||||
converted["cache_ratio"] = valueMap(cacheRatioMap)
|
||||
}
|
||||
if len(createCacheRatioMap) > 0 {
|
||||
converted["create_cache_ratio"] = valueMap(createCacheRatioMap)
|
||||
}
|
||||
if len(imageRatioMap) > 0 {
|
||||
converted["image_ratio"] = valueMap(imageRatioMap)
|
||||
}
|
||||
if len(audioRatioMap) > 0 {
|
||||
converted["audio_ratio"] = valueMap(audioRatioMap)
|
||||
}
|
||||
if len(audioCompletionRatioMap) > 0 {
|
||||
converted["audio_completion_ratio"] = valueMap(audioCompletionRatioMap)
|
||||
}
|
||||
|
||||
if len(modelPriceMap) > 0 {
|
||||
priceAny := make(map[string]any, len(modelPriceMap))
|
||||
@@ -358,6 +481,12 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
}
|
||||
converted["model_price"] = priceAny
|
||||
}
|
||||
if len(billingModeMap) > 0 {
|
||||
converted[billing_setting.BillingModeField] = valueMap(billingModeMap)
|
||||
}
|
||||
if len(billingExprMap) > 0 {
|
||||
converted[billing_setting.BillingExprField] = valueMap(billingExprMap)
|
||||
}
|
||||
|
||||
ch <- upstreamResult{Name: uniqueName, Data: converted}
|
||||
}(chn)
|
||||
@@ -366,7 +495,7 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
wg.Wait()
|
||||
close(ch)
|
||||
|
||||
localData := ratio_setting.GetExposedData()
|
||||
localData := getLocalPricingSyncData()
|
||||
|
||||
var testResults []dto.TestResult
|
||||
var successfulChannels []struct {
|
||||
@@ -412,22 +541,16 @@ func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||
|
||||
allModels := make(map[string]struct{})
|
||||
|
||||
for _, ratioType := range ratioTypes {
|
||||
if localRatioAny, ok := localData[ratioType]; ok {
|
||||
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||
for modelName := range localRatio {
|
||||
allModels[modelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
for _, field := range pricingSyncFields {
|
||||
for modelName := range valueMap(localData[field]) {
|
||||
allModels[modelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
for _, channel := range successfulChannels {
|
||||
for _, ratioType := range ratioTypes {
|
||||
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||
for modelName := range upstreamRatio {
|
||||
allModels[modelName] = struct{}{}
|
||||
}
|
||||
for _, field := range pricingSyncFields {
|
||||
for modelName := range valueMap(channel.data[field]) {
|
||||
allModels[modelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -438,10 +561,10 @@ func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||
for _, channel := range successfulChannels {
|
||||
confidenceMap[channel.name] = make(map[string]bool)
|
||||
|
||||
modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
|
||||
completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
|
||||
modelRatios := valueMap(channel.data["model_ratio"])
|
||||
completionRatios := valueMap(channel.data["completion_ratio"])
|
||||
|
||||
if hasModelRatio && hasCompletionRatio {
|
||||
if len(modelRatios) > 0 && len(completionRatios) > 0 {
|
||||
// 遍历所有模型,检查是否满足不可信条件
|
||||
for modelName := range allModels {
|
||||
// 默认为可信
|
||||
@@ -451,12 +574,10 @@ func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||
if modelRatioVal, ok := modelRatios[modelName]; ok {
|
||||
if completionRatioVal, ok := completionRatios[modelName]; ok {
|
||||
// 转换为float64进行比较
|
||||
if modelRatioFloat, ok := modelRatioVal.(float64); ok {
|
||||
if completionRatioFloat, ok := completionRatioVal.(float64); ok {
|
||||
if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
|
||||
confidenceMap[channel.name][modelName] = false
|
||||
}
|
||||
}
|
||||
modelRatioFloat, modelRatioOK := asFloat64(modelRatioVal)
|
||||
completionRatioFloat, completionRatioOK := asFloat64(completionRatioVal)
|
||||
if modelRatioOK && completionRatioOK && nearlyEqual(modelRatioFloat, 37.5) && nearlyEqual(completionRatioFloat, 1.0) {
|
||||
confidenceMap[channel.name][modelName] = false
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -470,14 +591,10 @@ func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||
}
|
||||
|
||||
for modelName := range allModels {
|
||||
for _, ratioType := range ratioTypes {
|
||||
for _, ratioType := range pricingSyncFields {
|
||||
var localValue interface{} = nil
|
||||
if localRatioAny, ok := localData[ratioType]; ok {
|
||||
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||
if val, exists := localRatio[modelName]; exists {
|
||||
localValue = val
|
||||
}
|
||||
}
|
||||
if val, exists := valueMap(localData[ratioType])[modelName]; exists {
|
||||
localValue = normalizeSyncValue(ratioType, val)
|
||||
}
|
||||
|
||||
upstreamValues := make(map[string]interface{})
|
||||
@@ -488,16 +605,14 @@ func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||
for _, channel := range successfulChannels {
|
||||
var upstreamValue interface{} = nil
|
||||
|
||||
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||
if val, exists := upstreamRatio[modelName]; exists {
|
||||
upstreamValue = val
|
||||
hasUpstreamValue = true
|
||||
if val, exists := valueMap(channel.data[ratioType])[modelName]; exists {
|
||||
upstreamValue = normalizeSyncValue(ratioType, val)
|
||||
hasUpstreamValue = true
|
||||
|
||||
if localValue != nil && !valuesEqual(localValue, val) {
|
||||
hasDifference = true
|
||||
} else if valuesEqual(localValue, val) {
|
||||
upstreamValue = "same"
|
||||
}
|
||||
if localValue != nil && !valuesEqual(localValue, upstreamValue) {
|
||||
hasDifference = true
|
||||
} else if valuesEqual(localValue, upstreamValue) {
|
||||
upstreamValue = "same"
|
||||
}
|
||||
}
|
||||
if upstreamValue == nil && localValue == nil {
|
||||
|
||||
@@ -13,7 +13,10 @@ import (
|
||||
|
||||
const (
|
||||
// SecureVerificationSessionKey means the user has fully passed secure verification.
|
||||
SecureVerificationSessionKey = "secure_verified_at"
|
||||
SecureVerificationSessionKey = "secure_verified_at"
|
||||
secureVerificationMethodSessionKey = "secure_verified_method"
|
||||
secureVerificationMethod2FA = "2fa"
|
||||
secureVerificationMethodPasskey = "passkey"
|
||||
// PasskeyReadySessionKey means WebAuthn finished and /api/verify can finalize step-up verification.
|
||||
PasskeyReadySessionKey = "secure_passkey_ready_at"
|
||||
// SecureVerificationTimeout 验证有效期(秒)
|
||||
@@ -120,7 +123,7 @@ func UniversalVerify(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 验证成功,在 session 中记录时间戳
|
||||
now, err := setSecureVerificationSession(c)
|
||||
now, err := setSecureVerificationSession(c, req.Method)
|
||||
if err != nil {
|
||||
common.ApiError(c, fmt.Errorf("保存验证状态失败: %v", err))
|
||||
return
|
||||
@@ -139,11 +142,12 @@ func UniversalVerify(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func setSecureVerificationSession(c *gin.Context) (int64, error) {
|
||||
func setSecureVerificationSession(c *gin.Context, method string) (int64, error) {
|
||||
session := sessions.Default(c)
|
||||
session.Delete(PasskeyReadySessionKey)
|
||||
now := time.Now().Unix()
|
||||
session.Set(SecureVerificationSessionKey, now)
|
||||
session.Set(secureVerificationMethodSessionKey, method)
|
||||
if err := session.Save(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@@ -2,11 +2,13 @@ package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
@@ -24,14 +26,14 @@ func SubscriptionRequestCreemPay(c *gin.Context) {
|
||||
// Keep body for debugging consistency (like RequestCreemPay)
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
log.Printf("read subscription creem pay req body err: %v", err)
|
||||
c.JSON(200, gin.H{"message": "error", "data": "read query error"})
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 订阅支付请求读取失败 error=%q", err.Error()))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "read query error"})
|
||||
return
|
||||
}
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil || req.PlanId <= 0 {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -81,16 +83,17 @@ func SubscriptionRequestCreemPay(c *gin.Context) {
|
||||
|
||||
// create pending order first
|
||||
order := &model.SubscriptionOrder{
|
||||
UserId: userId,
|
||||
PlanId: plan.Id,
|
||||
Money: plan.PriceAmount,
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: PaymentMethodCreem,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: userId,
|
||||
PlanId: plan.Id,
|
||||
Money: plan.PriceAmount,
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: model.PaymentMethodCreem,
|
||||
PaymentProvider: model.PaymentProviderCreem,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
if err := order.Insert(); err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -112,14 +115,14 @@ func SubscriptionRequestCreemPay(c *gin.Context) {
|
||||
Quota: 0,
|
||||
}
|
||||
|
||||
checkoutUrl, err := genCreemLink(referenceId, product, user.Email, user.Username)
|
||||
checkoutUrl, err := genCreemLink(c.Request.Context(), referenceId, product, user.Email, user.Username)
|
||||
if err != nil {
|
||||
log.Printf("获取Creem支付链接失败: %v", err)
|
||||
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 订阅支付链接创建失败 trade_no=%s product_id=%s error=%q", referenceId, product.ProductId, err.Error()))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"checkout_url": checkoutUrl,
|
||||
|
||||
@@ -82,13 +82,14 @@ func SubscriptionRequestEpay(c *gin.Context) {
|
||||
}
|
||||
|
||||
order := &model.SubscriptionOrder{
|
||||
UserId: userId,
|
||||
PlanId: plan.Id,
|
||||
Money: plan.PriceAmount,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: req.PaymentMethod,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: userId,
|
||||
PlanId: plan.Id,
|
||||
Money: plan.PriceAmount,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: req.PaymentMethod,
|
||||
PaymentProvider: model.PaymentProviderEpay,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
if err := order.Insert(); err != nil {
|
||||
common.ApiErrorMsg(c, "创建订单失败")
|
||||
@@ -104,7 +105,7 @@ func SubscriptionRequestEpay(c *gin.Context) {
|
||||
ReturnUrl: returnUrl,
|
||||
})
|
||||
if err != nil {
|
||||
_ = model.ExpireSubscriptionOrder(tradeNo)
|
||||
_ = model.ExpireSubscriptionOrder(tradeNo, model.PaymentProviderEpay)
|
||||
common.ApiErrorMsg(c, "拉起支付失败")
|
||||
return
|
||||
}
|
||||
@@ -156,7 +157,7 @@ func SubscriptionEpayNotify(c *gin.Context) {
|
||||
LockOrder(verifyInfo.ServiceTradeNo)
|
||||
defer UnlockOrder(verifyInfo.ServiceTradeNo)
|
||||
|
||||
if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo)); err != nil {
|
||||
if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), model.PaymentProviderEpay, verifyInfo.Type); err != nil {
|
||||
_, _ = c.Writer.Write([]byte("fail"))
|
||||
return
|
||||
}
|
||||
@@ -205,7 +206,7 @@ func SubscriptionEpayReturn(c *gin.Context) {
|
||||
if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
|
||||
LockOrder(verifyInfo.ServiceTradeNo)
|
||||
defer UnlockOrder(verifyInfo.ServiceTradeNo)
|
||||
if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo)); err != nil {
|
||||
if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), model.PaymentProviderEpay, verifyInfo.Type); err != nil {
|
||||
c.Redirect(http.StatusFound, system_setting.ServerAddress+"/console/topup?pay=fail")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,12 +2,12 @@ package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
@@ -78,19 +78,20 @@ func SubscriptionRequestStripePay(c *gin.Context) {
|
||||
|
||||
payLink, err := genStripeSubscriptionLink(referenceId, user.StripeCustomer, user.Email, plan.StripePriceId)
|
||||
if err != nil {
|
||||
log.Println("获取Stripe Checkout支付链接失败", err)
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Stripe 订阅支付链接创建失败 trade_no=%s plan_id=%d error=%q", referenceId, plan.Id, err.Error()))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
return
|
||||
}
|
||||
|
||||
order := &model.SubscriptionOrder{
|
||||
UserId: userId,
|
||||
PlanId: plan.Id,
|
||||
Money: plan.PriceAmount,
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: PaymentMethodStripe,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: userId,
|
||||
PlanId: plan.Id,
|
||||
Money: plan.PriceAmount,
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: model.PaymentMethodStripe,
|
||||
PaymentProvider: model.PaymentProviderStripe,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
if err := order.Insert(); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
|
||||
@@ -0,0 +1,313 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||
)
|
||||
|
||||
func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
||||
for channelId, taskIds := range taskChannelM {
|
||||
if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
|
||||
if len(taskIds) == 0 {
|
||||
return nil
|
||||
}
|
||||
cacheGetChannel, err := model.CacheGetChannel(channelId)
|
||||
if err != nil {
|
||||
errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
|
||||
"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
|
||||
"status": "FAILURE",
|
||||
"progress": "100%",
|
||||
})
|
||||
if errUpdate != nil {
|
||||
common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
|
||||
}
|
||||
return fmt.Errorf("CacheGetChannel failed: %w", err)
|
||||
}
|
||||
adaptor := relay.GetTaskAdaptor(platform)
|
||||
if adaptor == nil {
|
||||
return fmt.Errorf("video adaptor not found")
|
||||
}
|
||||
info := &relaycommon.RelayInfo{}
|
||||
info.ChannelMeta = &relaycommon.ChannelMeta{
|
||||
ChannelBaseUrl: cacheGetChannel.GetBaseURL(),
|
||||
}
|
||||
info.ApiKey = cacheGetChannel.Key
|
||||
adaptor.Init(info)
|
||||
for _, taskId := range taskIds {
|
||||
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
|
||||
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
proxy := channel.GetSetting().Proxy
|
||||
|
||||
task := taskM[taskId]
|
||||
if task == nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
|
||||
return fmt.Errorf("task %s not found", taskId)
|
||||
}
|
||||
key := channel.Key
|
||||
|
||||
privateData := task.PrivateData
|
||||
if privateData.Key != "" {
|
||||
key = privateData.Key
|
||||
}
|
||||
resp, err := adaptor.FetchTask(baseURL, key, map[string]any{
|
||||
"task_id": taskId,
|
||||
"action": task.Action,
|
||||
}, proxy)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
|
||||
}
|
||||
//if resp.StatusCode != http.StatusOK {
|
||||
//return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
|
||||
//}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask response: %s", string(responseBody)))
|
||||
|
||||
taskResult := &relaycommon.TaskInfo{}
|
||||
// try parse as New API response format
|
||||
var responseItems dto.TaskResponse[model.Task]
|
||||
if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
|
||||
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask parsed as new api response format: %+v", responseItems))
|
||||
t := responseItems.Data
|
||||
taskResult.TaskID = t.TaskID
|
||||
taskResult.Status = string(t.Status)
|
||||
taskResult.Url = t.FailReason
|
||||
taskResult.Progress = t.Progress
|
||||
taskResult.Reason = t.FailReason
|
||||
task.Data = t.Data
|
||||
} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
|
||||
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
|
||||
} else {
|
||||
task.Data = redactVideoResponseBody(responseBody)
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask taskResult: %+v", taskResult))
|
||||
|
||||
now := time.Now().Unix()
|
||||
if taskResult.Status == "" {
|
||||
//return fmt.Errorf("task %s status is empty", taskId)
|
||||
taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
|
||||
}
|
||||
|
||||
// 记录原本的状态,防止重复退款
|
||||
shouldRefund := false
|
||||
quota := task.Quota
|
||||
preStatus := task.Status
|
||||
|
||||
task.Status = model.TaskStatus(taskResult.Status)
|
||||
switch taskResult.Status {
|
||||
case model.TaskStatusSubmitted:
|
||||
task.Progress = "10%"
|
||||
case model.TaskStatusQueued:
|
||||
task.Progress = "20%"
|
||||
case model.TaskStatusInProgress:
|
||||
task.Progress = "30%"
|
||||
if task.StartTime == 0 {
|
||||
task.StartTime = now
|
||||
}
|
||||
case model.TaskStatusSuccess:
|
||||
task.Progress = "100%"
|
||||
if task.FinishTime == 0 {
|
||||
task.FinishTime = now
|
||||
}
|
||||
if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") {
|
||||
task.FailReason = taskResult.Url
|
||||
}
|
||||
|
||||
// 如果返回了 total_tokens 并且配置了模型倍率(非固定价格),则重新计费
|
||||
if taskResult.TotalTokens > 0 {
|
||||
// 获取模型名称
|
||||
var taskData map[string]interface{}
|
||||
if err := json.Unmarshal(task.Data, &taskData); err == nil {
|
||||
if modelName, ok := taskData["model"].(string); ok && modelName != "" {
|
||||
// 获取模型价格和倍率
|
||||
modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
|
||||
// 只有配置了倍率(非固定价格)时才按 token 重新计费
|
||||
if hasRatioSetting && modelRatio > 0 {
|
||||
// 获取用户和组的倍率信息
|
||||
group := task.Group
|
||||
if group == "" {
|
||||
user, err := model.GetUserById(task.UserId, false)
|
||||
if err == nil {
|
||||
group = user.Group
|
||||
}
|
||||
}
|
||||
if group != "" {
|
||||
groupRatio := ratio_setting.GetGroupRatio(group)
|
||||
userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group)
|
||||
|
||||
var finalGroupRatio float64
|
||||
if hasUserGroupRatio {
|
||||
finalGroupRatio = userGroupRatio
|
||||
} else {
|
||||
finalGroupRatio = groupRatio
|
||||
}
|
||||
|
||||
// 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
|
||||
actualQuota := int(float64(taskResult.TotalTokens) * modelRatio * finalGroupRatio)
|
||||
|
||||
// 计算差额
|
||||
preConsumedQuota := task.Quota
|
||||
quotaDelta := actualQuota - preConsumedQuota
|
||||
|
||||
if quotaDelta > 0 {
|
||||
// 需要补扣费
|
||||
logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后补扣费:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
|
||||
task.TaskID,
|
||||
logger.LogQuota(quotaDelta),
|
||||
logger.LogQuota(actualQuota),
|
||||
logger.LogQuota(preConsumedQuota),
|
||||
taskResult.TotalTokens,
|
||||
))
|
||||
if err := model.DecreaseUserQuota(task.UserId, quotaDelta, false); err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("补扣费失败: %s", err.Error()))
|
||||
} else {
|
||||
model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
|
||||
model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
|
||||
task.Quota = actualQuota // 更新任务记录的实际扣费额度
|
||||
|
||||
// 记录消费日志
|
||||
logContent := fmt.Sprintf("视频任务成功补扣费,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,补扣费 %s",
|
||||
modelRatio, finalGroupRatio, taskResult.TotalTokens,
|
||||
logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(quotaDelta))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
} else if quotaDelta < 0 {
|
||||
// 需要退还多扣的费用
|
||||
refundQuota := -quotaDelta
|
||||
logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后返还:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
|
||||
task.TaskID,
|
||||
logger.LogQuota(refundQuota),
|
||||
logger.LogQuota(actualQuota),
|
||||
logger.LogQuota(preConsumedQuota),
|
||||
taskResult.TotalTokens,
|
||||
))
|
||||
if err := model.IncreaseUserQuota(task.UserId, refundQuota, false); err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("退还预扣费失败: %s", err.Error()))
|
||||
} else {
|
||||
task.Quota = actualQuota // 更新任务记录的实际扣费额度
|
||||
|
||||
// 记录退款日志
|
||||
logContent := fmt.Sprintf("视频任务成功退还多扣费用,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,退还 %s",
|
||||
modelRatio, finalGroupRatio, taskResult.TotalTokens,
|
||||
logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(refundQuota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
} else {
|
||||
// quotaDelta == 0, 预扣费刚好准确
|
||||
logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)",
|
||||
task.TaskID, logger.LogQuota(actualQuota), taskResult.TotalTokens))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case model.TaskStatusFailure:
|
||||
logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
|
||||
task.Status = model.TaskStatusFailure
|
||||
task.Progress = "100%"
|
||||
if task.FinishTime == 0 {
|
||||
task.FinishTime = now
|
||||
}
|
||||
task.FailReason = taskResult.Reason
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
|
||||
taskResult.Progress = "100%"
|
||||
if quota != 0 {
|
||||
if preStatus != model.TaskStatusFailure {
|
||||
shouldRefund = true
|
||||
} else {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID))
|
||||
}
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
|
||||
}
|
||||
if taskResult.Progress != "" {
|
||||
task.Progress = taskResult.Progress
|
||||
}
|
||||
if err := task.Update(); err != nil {
|
||||
common.SysLog("UpdateVideoTask task error: " + err.Error())
|
||||
shouldRefund = false
|
||||
}
|
||||
|
||||
if shouldRefund {
|
||||
// 任务失败且之前状态不是失败才退还额度,防止重复退还
|
||||
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
|
||||
logger.LogWarn(ctx, "Failed to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func redactVideoResponseBody(body []byte) []byte {
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return body
|
||||
}
|
||||
resp, _ := m["response"].(map[string]any)
|
||||
if resp != nil {
|
||||
delete(resp, "bytesBase64Encoded")
|
||||
if v, ok := resp["video"].(string); ok {
|
||||
resp["video"] = truncateBase64(v)
|
||||
}
|
||||
if vs, ok := resp["videos"].([]any); ok {
|
||||
for i := range vs {
|
||||
if vm, ok := vs[i].(map[string]any); ok {
|
||||
delete(vm, "bytesBase64Encoded")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
b, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func truncateBase64(s string) string {
|
||||
const maxKeep = 256
|
||||
if len(s) <= maxKeep {
|
||||
return s
|
||||
}
|
||||
return s[:maxKeep] + "..."
|
||||
}
|
||||
+271
-5
@@ -2,10 +2,12 @@ package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -14,6 +16,8 @@ import (
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -38,7 +42,36 @@ type tokenKeyResponse struct {
|
||||
Key string `json:"key"`
|
||||
}
|
||||
|
||||
func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
|
||||
type sqliteColumnInfo struct {
|
||||
Name string `gorm:"column:name"`
|
||||
Type string `gorm:"column:type"`
|
||||
}
|
||||
|
||||
type legacyToken struct {
|
||||
Id int `gorm:"primaryKey"`
|
||||
UserId int `gorm:"index"`
|
||||
Key string `gorm:"column:key;type:char(48);uniqueIndex"`
|
||||
Status int `gorm:"default:1"`
|
||||
Name string `gorm:"index"`
|
||||
CreatedTime int64 `gorm:"bigint"`
|
||||
AccessedTime int64 `gorm:"bigint"`
|
||||
ExpiredTime int64 `gorm:"bigint;default:-1"`
|
||||
RemainQuota int `gorm:"default:0"`
|
||||
UnlimitedQuota bool
|
||||
ModelLimitsEnabled bool
|
||||
ModelLimits string `gorm:"type:text"`
|
||||
AllowIps *string `gorm:"default:''"`
|
||||
UsedQuota int `gorm:"default:0"`
|
||||
Group string `gorm:"column:group;default:''"`
|
||||
CrossGroupRetry bool
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
}
|
||||
|
||||
func (legacyToken) TableName() string {
|
||||
return "tokens"
|
||||
}
|
||||
|
||||
func openTokenControllerTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
@@ -55,10 +88,6 @@ func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
|
||||
model.DB = db
|
||||
model.LOG_DB = db
|
||||
|
||||
if err := db.AutoMigrate(&model.Token{}); err != nil {
|
||||
t.Fatalf("failed to migrate token table: %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
sqlDB, err := db.DB()
|
||||
if err == nil {
|
||||
@@ -69,6 +98,69 @@ func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
|
||||
return db
|
||||
}
|
||||
|
||||
func migrateTokenControllerTestDB(t *testing.T, db *gorm.DB) {
|
||||
t.Helper()
|
||||
|
||||
if err := db.AutoMigrate(&model.Token{}); err != nil {
|
||||
t.Fatalf("failed to migrate token table: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
db := openTokenControllerTestDB(t)
|
||||
migrateTokenControllerTestDB(t, db)
|
||||
return db
|
||||
}
|
||||
|
||||
func openTokenControllerExternalDB(t *testing.T, dialect string, dsn string) (*gorm.DB, *bool) {
|
||||
t.Helper()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
common.RedisEnabled = false
|
||||
common.UsingSQLite = false
|
||||
common.UsingMySQL = dialect == "mysql"
|
||||
common.UsingPostgreSQL = dialect == "postgres"
|
||||
|
||||
var (
|
||||
db *gorm.DB
|
||||
err error
|
||||
)
|
||||
switch dialect {
|
||||
case "mysql":
|
||||
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
|
||||
case "postgres":
|
||||
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||||
default:
|
||||
t.Fatalf("unsupported dialect %q", dialect)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open %s db: %v", dialect, err)
|
||||
}
|
||||
|
||||
model.DB = db
|
||||
model.LOG_DB = db
|
||||
|
||||
if db.Migrator().HasTable("tokens") {
|
||||
t.Skipf("refusing to run %s migration compatibility test against external database because tokens table already exists", dialect)
|
||||
}
|
||||
|
||||
managedTokensTable := new(bool)
|
||||
|
||||
t.Cleanup(func() {
|
||||
if *managedTokensTable && db.Migrator().HasTable("tokens") {
|
||||
_ = db.Migrator().DropTable("tokens")
|
||||
}
|
||||
sqlDB, err := db.DB()
|
||||
if err == nil {
|
||||
_ = sqlDB.Close()
|
||||
}
|
||||
})
|
||||
|
||||
return db, managedTokensTable
|
||||
}
|
||||
|
||||
func seedToken(t *testing.T, db *gorm.DB, userID int, name string, rawKey string) *model.Token {
|
||||
t.Helper()
|
||||
|
||||
@@ -124,6 +216,180 @@ func decodeAPIResponse(t *testing.T, recorder *httptest.ResponseRecorder) tokenA
|
||||
return response
|
||||
}
|
||||
|
||||
func getSQLiteColumnType(t *testing.T, db *gorm.DB, tableName string, columnName string) string {
|
||||
t.Helper()
|
||||
|
||||
var columns []sqliteColumnInfo
|
||||
if err := db.Raw("PRAGMA table_info(" + tableName + ")").Scan(&columns).Error; err != nil {
|
||||
t.Fatalf("failed to inspect %s schema: %v", tableName, err)
|
||||
}
|
||||
|
||||
for _, column := range columns {
|
||||
if column.Name == columnName {
|
||||
return strings.ToLower(column.Type)
|
||||
}
|
||||
}
|
||||
|
||||
t.Fatalf("column %s not found in %s schema", columnName, tableName)
|
||||
return ""
|
||||
}
|
||||
|
||||
func getTokenKeyColumnType(t *testing.T, db *gorm.DB, dialect string) string {
|
||||
t.Helper()
|
||||
|
||||
switch dialect {
|
||||
case "sqlite":
|
||||
return getSQLiteColumnType(t, db, "tokens", "key")
|
||||
case "mysql":
|
||||
var columnType string
|
||||
if err := db.Raw(`SELECT COLUMN_TYPE FROM information_schema.columns
|
||||
WHERE table_schema = DATABASE() AND table_name = ? AND column_name = ?`,
|
||||
"tokens", "key").Scan(&columnType).Error; err != nil {
|
||||
t.Fatalf("failed to inspect mysql token key column: %v", err)
|
||||
}
|
||||
return strings.ToLower(columnType)
|
||||
case "postgres":
|
||||
var dataType string
|
||||
var maxLength sql.NullInt64
|
||||
if err := db.Raw(`SELECT data_type, character_maximum_length
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = current_schema() AND table_name = ? AND column_name = ?`,
|
||||
"tokens", "key").Row().Scan(&dataType, &maxLength); err != nil {
|
||||
t.Fatalf("failed to inspect postgres token key column: %v", err)
|
||||
}
|
||||
switch strings.ToLower(dataType) {
|
||||
case "character varying":
|
||||
return fmt.Sprintf("varchar(%d)", maxLength.Int64)
|
||||
case "character":
|
||||
return fmt.Sprintf("char(%d)", maxLength.Int64)
|
||||
default:
|
||||
if maxLength.Valid {
|
||||
return fmt.Sprintf("%s(%d)", strings.ToLower(dataType), maxLength.Int64)
|
||||
}
|
||||
return strings.ToLower(dataType)
|
||||
}
|
||||
default:
|
||||
t.Fatalf("unsupported dialect %q", dialect)
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func runTokenMigrationCompatibilityTest(t *testing.T, db *gorm.DB, dialect string, managedTokensTable *bool) {
|
||||
t.Helper()
|
||||
|
||||
legacyKey := strings.Repeat("a", 48)
|
||||
longKey := strings.Repeat("b", 64)
|
||||
|
||||
if err := db.AutoMigrate(&legacyToken{}); err != nil {
|
||||
t.Fatalf("failed to create legacy token schema: %v", err)
|
||||
}
|
||||
if managedTokensTable != nil {
|
||||
*managedTokensTable = true
|
||||
}
|
||||
if err := db.Create(&legacyToken{
|
||||
UserId: 7,
|
||||
Key: legacyKey,
|
||||
Status: common.TokenStatusEnabled,
|
||||
Name: "legacy-token",
|
||||
CreatedTime: 1,
|
||||
AccessedTime: 1,
|
||||
ExpiredTime: -1,
|
||||
RemainQuota: 100,
|
||||
UnlimitedQuota: true,
|
||||
ModelLimitsEnabled: false,
|
||||
ModelLimits: "",
|
||||
AllowIps: common.GetPointer(""),
|
||||
UsedQuota: 0,
|
||||
Group: "default",
|
||||
CrossGroupRetry: false,
|
||||
}).Error; err != nil {
|
||||
t.Fatalf("failed to seed legacy token row: %v", err)
|
||||
}
|
||||
|
||||
if got := getTokenKeyColumnType(t, db, dialect); got != "char(48)" {
|
||||
t.Fatalf("expected legacy key column type char(48), got %q", got)
|
||||
}
|
||||
|
||||
migrateTokenControllerTestDB(t, db)
|
||||
|
||||
if got := getTokenKeyColumnType(t, db, dialect); got != "varchar(128)" {
|
||||
t.Fatalf("expected migrated key column type varchar(128), got %q", got)
|
||||
}
|
||||
|
||||
var migratedToken model.Token
|
||||
if err := db.First(&migratedToken, "name = ?", "legacy-token").Error; err != nil {
|
||||
t.Fatalf("failed to load migrated token row: %v", err)
|
||||
}
|
||||
if migratedToken.Key != legacyKey {
|
||||
t.Fatalf("expected migrated token key %q, got %q", legacyKey, migratedToken.Key)
|
||||
}
|
||||
if migratedToken.Name != "legacy-token" {
|
||||
t.Fatalf("expected migrated token name to be preserved, got %q", migratedToken.Name)
|
||||
}
|
||||
|
||||
inserted := model.Token{
|
||||
UserId: 8,
|
||||
Name: "long-token",
|
||||
Key: longKey,
|
||||
Status: common.TokenStatusEnabled,
|
||||
CreatedTime: 1,
|
||||
AccessedTime: 1,
|
||||
ExpiredTime: -1,
|
||||
RemainQuota: 200,
|
||||
UnlimitedQuota: true,
|
||||
ModelLimitsEnabled: false,
|
||||
ModelLimits: "",
|
||||
AllowIps: common.GetPointer(""),
|
||||
UsedQuota: 0,
|
||||
Group: "default",
|
||||
CrossGroupRetry: false,
|
||||
}
|
||||
if err := db.Create(&inserted).Error; err != nil {
|
||||
t.Fatalf("failed to insert long token after migration: %v", err)
|
||||
}
|
||||
|
||||
var fetched model.Token
|
||||
if err := db.First(&fetched, "id = ?", inserted.Id).Error; err != nil {
|
||||
t.Fatalf("failed to fetch long token after migration: %v", err)
|
||||
}
|
||||
if fetched.Key != longKey {
|
||||
t.Fatalf("expected long token key %q, got %q", longKey, fetched.Key)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenAutoMigrateUsesVarchar128KeyColumn(t *testing.T) {
|
||||
db := setupTokenControllerTestDB(t)
|
||||
|
||||
if got := getTokenKeyColumnType(t, db, "sqlite"); got != "varchar(128)" {
|
||||
t.Fatalf("expected key column type varchar(128), got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenMigrationFromChar48ToVarchar128(t *testing.T) {
|
||||
db := openTokenControllerTestDB(t)
|
||||
runTokenMigrationCompatibilityTest(t, db, "sqlite", nil)
|
||||
}
|
||||
|
||||
func TestTokenMigrationFromChar48ToVarchar128MySQL(t *testing.T) {
|
||||
dsn := os.Getenv("TEST_MYSQL_DSN")
|
||||
if dsn == "" {
|
||||
t.Skip("set TEST_MYSQL_DSN to run mysql migration compatibility test")
|
||||
}
|
||||
|
||||
db, managedTokensTable := openTokenControllerExternalDB(t, "mysql", dsn)
|
||||
runTokenMigrationCompatibilityTest(t, db, "mysql", managedTokensTable)
|
||||
}
|
||||
|
||||
func TestTokenMigrationFromChar48ToVarchar128Postgres(t *testing.T) {
|
||||
dsn := os.Getenv("TEST_POSTGRES_DSN")
|
||||
if dsn == "" {
|
||||
t.Skip("set TEST_POSTGRES_DSN to run postgres migration compatibility test")
|
||||
}
|
||||
|
||||
db, managedTokensTable := openTokenControllerExternalDB(t, "postgres", dsn)
|
||||
runTokenMigrationCompatibilityTest(t, db, "postgres", managedTokensTable)
|
||||
}
|
||||
|
||||
func TestGetAllTokensMasksKeyInResponse(t *testing.T) {
|
||||
db := setupTokenControllerTestDB(t)
|
||||
token := seedToken(t, db, 1, "list-token", "abcd1234efgh5678")
|
||||
|
||||
+97
-65
@@ -2,7 +2,7 @@ package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"sync"
|
||||
@@ -27,7 +27,7 @@ func GetTopUpInfo(c *gin.Context) {
|
||||
payMethods := operation_setting.PayMethods
|
||||
|
||||
// 如果启用了 Stripe 支付,添加到支付方法列表
|
||||
if setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "" {
|
||||
if isStripeTopUpEnabled() {
|
||||
// 检查是否已经包含 Stripe
|
||||
hasStripe := false
|
||||
for _, method := range payMethods {
|
||||
@@ -49,19 +49,11 @@ func GetTopUpInfo(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 如果启用了 Waffo 支付,添加到支付方法列表
|
||||
enableWaffo := setting.WaffoEnabled &&
|
||||
((!setting.WaffoSandbox &&
|
||||
setting.WaffoApiKey != "" &&
|
||||
setting.WaffoPrivateKey != "" &&
|
||||
setting.WaffoPublicCert != "") ||
|
||||
(setting.WaffoSandbox &&
|
||||
setting.WaffoSandboxApiKey != "" &&
|
||||
setting.WaffoSandboxPrivateKey != "" &&
|
||||
setting.WaffoSandboxPublicCert != ""))
|
||||
enableWaffo := isWaffoTopUpEnabled()
|
||||
if enableWaffo {
|
||||
hasWaffo := false
|
||||
for _, method := range payMethods {
|
||||
if method["type"] == "waffo" {
|
||||
if method["type"] == model.PaymentMethodWaffo {
|
||||
hasWaffo = true
|
||||
break
|
||||
}
|
||||
@@ -70,7 +62,7 @@ func GetTopUpInfo(c *gin.Context) {
|
||||
if !hasWaffo {
|
||||
waffoMethod := map[string]string{
|
||||
"name": "Waffo (Global Payment)",
|
||||
"type": "waffo",
|
||||
"type": model.PaymentMethodWaffo,
|
||||
"color": "rgba(var(--semi-blue-5), 1)",
|
||||
"min_topup": strconv.Itoa(setting.WaffoMinTopUp),
|
||||
}
|
||||
@@ -78,24 +70,46 @@ func GetTopUpInfo(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
enableWaffoPancake := isWaffoPancakeTopUpEnabled()
|
||||
if enableWaffoPancake {
|
||||
hasWaffoPancake := false
|
||||
for _, method := range payMethods {
|
||||
if method["type"] == model.PaymentMethodWaffoPancake {
|
||||
hasWaffoPancake = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasWaffoPancake {
|
||||
payMethods = append(payMethods, map[string]string{
|
||||
"name": "Waffo Pancake",
|
||||
"type": model.PaymentMethodWaffoPancake,
|
||||
"color": "rgba(var(--semi-orange-5), 1)",
|
||||
"min_topup": strconv.Itoa(setting.WaffoPancakeMinTopUp),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
data := gin.H{
|
||||
"enable_online_topup": operation_setting.PayAddress != "" && operation_setting.EpayId != "" && operation_setting.EpayKey != "",
|
||||
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
|
||||
"enable_creem_topup": setting.CreemApiKey != "" && setting.CreemProducts != "[]",
|
||||
"enable_waffo_topup": enableWaffo,
|
||||
"enable_online_topup": isEpayTopUpEnabled(),
|
||||
"enable_stripe_topup": isStripeTopUpEnabled(),
|
||||
"enable_creem_topup": isCreemTopUpEnabled(),
|
||||
"enable_waffo_topup": enableWaffo,
|
||||
"enable_waffo_pancake_topup": enableWaffoPancake,
|
||||
"waffo_pay_methods": func() interface{} {
|
||||
if enableWaffo {
|
||||
return setting.GetWaffoPayMethods()
|
||||
}
|
||||
return nil
|
||||
}(),
|
||||
"creem_products": setting.CreemProducts,
|
||||
"pay_methods": payMethods,
|
||||
"min_topup": operation_setting.MinTopUp,
|
||||
"stripe_min_topup": setting.StripeMinTopUp,
|
||||
"waffo_min_topup": setting.WaffoMinTopUp,
|
||||
"amount_options": operation_setting.GetPaymentSetting().AmountOptions,
|
||||
"discount": operation_setting.GetPaymentSetting().AmountDiscount,
|
||||
"creem_products": setting.CreemProducts,
|
||||
"pay_methods": payMethods,
|
||||
"min_topup": operation_setting.MinTopUp,
|
||||
"stripe_min_topup": setting.StripeMinTopUp,
|
||||
"waffo_min_topup": setting.WaffoMinTopUp,
|
||||
"waffo_pancake_min_topup": setting.WaffoPancakeMinTopUp,
|
||||
"amount_options": operation_setting.GetPaymentSetting().AmountOptions,
|
||||
"discount": operation_setting.GetPaymentSetting().AmountDiscount,
|
||||
}
|
||||
common.ApiSuccess(c, data)
|
||||
}
|
||||
@@ -167,28 +181,28 @@ func RequestEpay(c *gin.Context) {
|
||||
var req EpayRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
if req.Amount < getMinTopup() {
|
||||
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
|
||||
return
|
||||
}
|
||||
|
||||
id := c.GetInt("id")
|
||||
group, err := model.GetUserGroup(id, true)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "获取用户分组失败"})
|
||||
return
|
||||
}
|
||||
payMoney := getPayMoney(req.Amount, group)
|
||||
if payMoney < 0.01 {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
}
|
||||
|
||||
if !operation_setting.ContainsPayMethod(req.PaymentMethod) {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "支付方式不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -199,7 +213,7 @@ func RequestEpay(c *gin.Context) {
|
||||
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
|
||||
client := GetEpayClient()
|
||||
if client == nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "当前管理员未配置支付信息"})
|
||||
return
|
||||
}
|
||||
uri, params, err := client.Purchase(&epay.PurchaseArgs{
|
||||
@@ -212,7 +226,8 @@ func RequestEpay(c *gin.Context) {
|
||||
ReturnUrl: returnUrl,
|
||||
})
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 拉起支付失败 user_id=%d trade_no=%s payment_method=%s amount=%d error=%q", id, tradeNo, req.PaymentMethod, req.Amount, err.Error()))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
return
|
||||
}
|
||||
amount := req.Amount
|
||||
@@ -222,20 +237,23 @@ func RequestEpay(c *gin.Context) {
|
||||
amount = dAmount.Div(dQuotaPerUnit).IntPart()
|
||||
}
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: amount,
|
||||
Money: payMoney,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: req.PaymentMethod,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: "pending",
|
||||
UserId: id,
|
||||
Amount: amount,
|
||||
Money: payMoney,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: req.PaymentMethod,
|
||||
PaymentProvider: model.PaymentProviderEpay,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
err = topUp.Insert()
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 创建充值订单失败 user_id=%d trade_no=%s payment_method=%s amount=%d error=%q", id, tradeNo, req.PaymentMethod, req.Amount, err.Error()))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"message": "success", "data": params, "url": uri})
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 充值订单创建成功 user_id=%d trade_no=%s payment_method=%s amount=%d money=%.2f uri=%q params=%q", id, tradeNo, req.PaymentMethod, req.Amount, payMoney, uri, common.GetJsonString(params)))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success", "data": params, "url": uri})
|
||||
}
|
||||
|
||||
// tradeNo lock
|
||||
@@ -281,12 +299,18 @@ func UnlockOrder(tradeNo string) {
|
||||
}
|
||||
|
||||
func EpayNotify(c *gin.Context) {
|
||||
if !isEpayWebhookEnabled() {
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 webhook 被拒绝 reason=webhook_disabled path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
|
||||
_, _ = c.Writer.Write([]byte("fail"))
|
||||
return
|
||||
}
|
||||
|
||||
var params map[string]string
|
||||
|
||||
if c.Request.Method == "POST" {
|
||||
// POST 请求:从 POST body 解析参数
|
||||
if err := c.Request.ParseForm(); err != nil {
|
||||
log.Println("易支付回调POST解析失败:", err)
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 webhook POST 表单解析失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
|
||||
_, _ = c.Writer.Write([]byte("fail"))
|
||||
return
|
||||
}
|
||||
@@ -301,54 +325,63 @@ func EpayNotify(c *gin.Context) {
|
||||
return r
|
||||
}, map[string]string{})
|
||||
}
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 webhook 收到请求 path=%q client_ip=%s method=%s params=%q", c.Request.RequestURI, c.ClientIP(), c.Request.Method, common.GetJsonString(params)))
|
||||
|
||||
if len(params) == 0 {
|
||||
log.Println("易支付回调参数为空")
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 webhook 参数为空 path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
|
||||
_, _ = c.Writer.Write([]byte("fail"))
|
||||
return
|
||||
}
|
||||
client := GetEpayClient()
|
||||
if client == nil {
|
||||
log.Println("易支付回调失败 未找到配置信息")
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 client 未初始化 path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
|
||||
_, err := c.Writer.Write([]byte("fail"))
|
||||
if err != nil {
|
||||
log.Println("易支付回调写入失败")
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 webhook 响应写入失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
|
||||
}
|
||||
return
|
||||
}
|
||||
verifyInfo, err := client.Verify(params)
|
||||
if err == nil && verifyInfo.VerifyStatus {
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 webhook 验签成功 trade_no=%s callback_type=%s trade_status=%s client_ip=%s verify_info=%q", verifyInfo.ServiceTradeNo, verifyInfo.Type, verifyInfo.TradeStatus, c.ClientIP(), common.GetJsonString(verifyInfo)))
|
||||
_, err := c.Writer.Write([]byte("success"))
|
||||
if err != nil {
|
||||
log.Println("易支付回调写入失败")
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 webhook 响应写入失败 trade_no=%s client_ip=%s error=%q", verifyInfo.ServiceTradeNo, c.ClientIP(), err.Error()))
|
||||
}
|
||||
} else {
|
||||
_, err := c.Writer.Write([]byte("fail"))
|
||||
if err != nil {
|
||||
log.Println("易支付回调写入失败")
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 webhook 响应写入失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
|
||||
}
|
||||
if err != nil {
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 webhook 验签失败 path=%q client_ip=%s verify_error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
|
||||
} else {
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 webhook 验签失败 path=%q client_ip=%s verify_status=false", c.Request.RequestURI, c.ClientIP()))
|
||||
}
|
||||
log.Println("易支付回调签名验证失败")
|
||||
return
|
||||
}
|
||||
|
||||
if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
|
||||
log.Println(verifyInfo)
|
||||
LockOrder(verifyInfo.ServiceTradeNo)
|
||||
defer UnlockOrder(verifyInfo.ServiceTradeNo)
|
||||
topUp := model.GetTopUpByTradeNo(verifyInfo.ServiceTradeNo)
|
||||
if topUp == nil {
|
||||
log.Printf("易支付回调未找到订单: %v", verifyInfo)
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 回调订单不存在 trade_no=%s callback_type=%s client_ip=%s verify_info=%q", verifyInfo.ServiceTradeNo, verifyInfo.Type, c.ClientIP(), common.GetJsonString(verifyInfo)))
|
||||
return
|
||||
}
|
||||
if topUp.PaymentMethod == "stripe" || topUp.PaymentMethod == "creem" || topUp.PaymentMethod == "waffo" {
|
||||
log.Printf("易支付回调订单支付方式不匹配: %s, 订单号: %s", topUp.PaymentMethod, verifyInfo.ServiceTradeNo)
|
||||
if topUp.PaymentProvider != model.PaymentProviderEpay {
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 订单支付网关不匹配 trade_no=%s order_provider=%s callback_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentProvider, verifyInfo.Type, c.ClientIP()))
|
||||
return
|
||||
}
|
||||
if topUp.Status == "pending" {
|
||||
topUp.Status = "success"
|
||||
if topUp.Status == common.TopUpStatusPending {
|
||||
if topUp.PaymentMethod != verifyInfo.Type {
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 实际支付方式与订单不同 trade_no=%s order_payment_method=%s actual_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentMethod, verifyInfo.Type, c.ClientIP()))
|
||||
topUp.PaymentMethod = verifyInfo.Type
|
||||
}
|
||||
topUp.Status = common.TopUpStatusSuccess
|
||||
err := topUp.Update()
|
||||
if err != nil {
|
||||
log.Printf("易支付回调更新订单失败: %v", topUp)
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 更新充值订单失败 trade_no=%s user_id=%d client_ip=%s error=%q topup=%q", topUp.TradeNo, topUp.UserId, c.ClientIP(), err.Error(), common.GetJsonString(topUp)))
|
||||
return
|
||||
}
|
||||
//user, _ := model.GetUserById(topUp.UserId, false)
|
||||
@@ -358,14 +391,14 @@ func EpayNotify(c *gin.Context) {
|
||||
quotaToAdd := int(dAmount.Mul(dQuotaPerUnit).IntPart())
|
||||
err = model.IncreaseUserQuota(topUp.UserId, quotaToAdd, true)
|
||||
if err != nil {
|
||||
log.Printf("易支付回调更新用户失败: %v", topUp)
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 更新用户额度失败 trade_no=%s user_id=%d client_ip=%s quota_to_add=%d error=%q topup=%q", topUp.TradeNo, topUp.UserId, c.ClientIP(), quotaToAdd, err.Error(), common.GetJsonString(topUp)))
|
||||
return
|
||||
}
|
||||
log.Printf("易支付回调更新用户成功 %v", topUp)
|
||||
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", logger.LogQuota(quotaToAdd), topUp.Money))
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 充值成功 trade_no=%s user_id=%d client_ip=%s quota_to_add=%d money=%.2f topup=%q", topUp.TradeNo, topUp.UserId, c.ClientIP(), quotaToAdd, topUp.Money, common.GetJsonString(topUp)))
|
||||
model.RecordTopupLog(topUp.UserId, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", logger.LogQuota(quotaToAdd), topUp.Money), c.ClientIP(), topUp.PaymentMethod, "epay")
|
||||
}
|
||||
} else {
|
||||
log.Printf("易支付异常回调: %v", verifyInfo)
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 webhook 忽略事件 trade_no=%s callback_type=%s trade_status=%s client_ip=%s verify_info=%q", verifyInfo.ServiceTradeNo, verifyInfo.Type, verifyInfo.TradeStatus, c.ClientIP(), common.GetJsonString(verifyInfo)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -373,26 +406,26 @@ func RequestAmount(c *gin.Context) {
|
||||
var req AmountRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Amount < getMinTopup() {
|
||||
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
|
||||
return
|
||||
}
|
||||
id := c.GetInt("id")
|
||||
group, err := model.GetUserGroup(id, true)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "获取用户分组失败"})
|
||||
return
|
||||
}
|
||||
payMoney := getPayMoney(req.Amount, group)
|
||||
if payMoney <= 0.01 {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
|
||||
}
|
||||
|
||||
func GetUserTopUps(c *gin.Context) {
|
||||
@@ -461,10 +494,9 @@ func AdminCompleteTopUp(c *gin.Context) {
|
||||
LockOrder(req.TradeNo)
|
||||
defer UnlockOrder(req.TradeNo)
|
||||
|
||||
if err := model.ManualCompleteTopUp(req.TradeNo); err != nil {
|
||||
if err := model.ManualCompleteTopUp(req.TradeNo, c.ClientIP()); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.ApiSuccess(c, nil)
|
||||
}
|
||||
|
||||
|
||||
+64
-72
@@ -2,6 +2,7 @@ package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
@@ -9,10 +10,10 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
@@ -20,10 +21,7 @@ import (
|
||||
"github.com/thanhpk/randstr"
|
||||
)
|
||||
|
||||
const (
|
||||
PaymentMethodCreem = "creem"
|
||||
CreemSignatureHeader = "creem-signature"
|
||||
)
|
||||
const CreemSignatureHeader = "creem-signature"
|
||||
|
||||
var creemAdaptor = &CreemAdaptor{}
|
||||
|
||||
@@ -37,9 +35,9 @@ func generateCreemSignature(payload string, secret string) string {
|
||||
// 验证Creem webhook签名
|
||||
func verifyCreemSignature(payload string, signature string, secret string) bool {
|
||||
if secret == "" {
|
||||
log.Printf("Creem webhook secret not set")
|
||||
logger.LogWarn(context.Background(), fmt.Sprintf("Creem webhook secret 未配置 test_mode=%t signature=%q body=%q", setting.CreemTestMode, signature, payload))
|
||||
if setting.CreemTestMode {
|
||||
log.Printf("Skip Creem webhook sign verify in test mode")
|
||||
logger.LogInfo(context.Background(), fmt.Sprintf("Creem webhook 验签已跳过 reason=test_mode signature=%q body=%q", signature, payload))
|
||||
return true
|
||||
}
|
||||
return false
|
||||
@@ -66,13 +64,13 @@ type CreemAdaptor struct {
|
||||
}
|
||||
|
||||
func (*CreemAdaptor) RequestPay(c *gin.Context, req *CreemPayRequest) {
|
||||
if req.PaymentMethod != PaymentMethodCreem {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"})
|
||||
if req.PaymentMethod != model.PaymentMethodCreem {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "不支持的支付渠道"})
|
||||
return
|
||||
}
|
||||
|
||||
if req.ProductId == "" {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "请选择产品"})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "请选择产品"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -80,8 +78,8 @@ func (*CreemAdaptor) RequestPay(c *gin.Context, req *CreemPayRequest) {
|
||||
var products []CreemProduct
|
||||
err := json.Unmarshal([]byte(setting.CreemProducts), &products)
|
||||
if err != nil {
|
||||
log.Println("解析Creem产品列表失败", err)
|
||||
c.JSON(200, gin.H{"message": "error", "data": "产品配置错误"})
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 产品配置解析失败 user_id=%d error=%q", c.GetInt("id"), err.Error()))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "产品配置错误"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -95,7 +93,7 @@ func (*CreemAdaptor) RequestPay(c *gin.Context, req *CreemPayRequest) {
|
||||
}
|
||||
|
||||
if selectedProduct == nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "产品不存在"})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "产品不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -108,33 +106,33 @@ func (*CreemAdaptor) RequestPay(c *gin.Context, req *CreemPayRequest) {
|
||||
|
||||
// 先创建订单记录,使用产品配置的金额和充值额度
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: selectedProduct.Quota, // 充值额度
|
||||
Money: selectedProduct.Price, // 支付金额
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: PaymentMethodCreem,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: id,
|
||||
Amount: selectedProduct.Quota, // 充值额度
|
||||
Money: selectedProduct.Price, // 支付金额
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: model.PaymentMethodCreem,
|
||||
PaymentProvider: model.PaymentProviderCreem,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
err = topUp.Insert()
|
||||
if err != nil {
|
||||
log.Printf("创建Creem订单失败: %v", err)
|
||||
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 创建充值订单失败 user_id=%d trade_no=%s product_id=%s error=%q", id, referenceId, selectedProduct.ProductId, err.Error()))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
return
|
||||
}
|
||||
|
||||
// 创建支付链接,传入用户邮箱
|
||||
checkoutUrl, err := genCreemLink(referenceId, selectedProduct, user.Email, user.Username)
|
||||
checkoutUrl, err := genCreemLink(c.Request.Context(), referenceId, selectedProduct, user.Email, user.Username)
|
||||
if err != nil {
|
||||
log.Printf("获取Creem支付链接失败: %v", err)
|
||||
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 创建支付链接失败 user_id=%d trade_no=%s product_id=%s error=%q", id, referenceId, selectedProduct.ProductId, err.Error()))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Creem订单创建成功 - 用户ID: %d, 订单号: %s, 产品: %s, 充值额度: %d, 支付金额: %.2f",
|
||||
id, referenceId, selectedProduct.Name, selectedProduct.Quota, selectedProduct.Price)
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 充值订单创建成功 user_id=%d trade_no=%s product_id=%s product_name=%q quota=%d money=%.2f", id, referenceId, selectedProduct.ProductId, selectedProduct.Name, selectedProduct.Quota, selectedProduct.Price))
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"checkout_url": checkoutUrl,
|
||||
@@ -149,20 +147,19 @@ func RequestCreemPay(c *gin.Context) {
|
||||
// 读取body内容用于打印,同时保留原始数据供后续使用
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
log.Printf("read creem pay req body err: %v", err)
|
||||
c.JSON(200, gin.H{"message": "error", "data": "read query error"})
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 支付请求读取失败 error=%q", err.Error()))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "read query error"})
|
||||
return
|
||||
}
|
||||
|
||||
// 打印body内容
|
||||
log.Printf("creem pay request body: %s", string(bodyBytes))
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 支付请求已收到 user_id=%d body=%q", c.GetInt("id"), string(bodyBytes)))
|
||||
|
||||
// 重新设置body供后续的ShouldBindJSON使用
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
|
||||
err = c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
creemAdaptor.RequestPay(c, &req)
|
||||
@@ -230,35 +227,37 @@ type CreemWebhookEvent struct {
|
||||
}
|
||||
|
||||
func CreemWebhook(c *gin.Context) {
|
||||
if !isCreemWebhookEnabled() {
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Creem webhook 被拒绝 reason=webhook_disabled path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
|
||||
c.AbortWithStatus(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// 读取body内容用于打印,同时保留原始数据供后续使用
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
log.Printf("读取Creem Webhook请求body失败: %v", err)
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem webhook 读取请求体失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取签名头
|
||||
signature := c.GetHeader(CreemSignatureHeader)
|
||||
|
||||
// 打印关键信息(避免输出完整敏感payload)
|
||||
log.Printf("Creem Webhook - URI: %s", c.Request.RequestURI)
|
||||
if setting.CreemTestMode {
|
||||
log.Printf("Creem Webhook - Signature: %s , Body: %s", signature, bodyBytes)
|
||||
} else if signature == "" {
|
||||
log.Printf("Creem Webhook缺少签名头")
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem webhook 收到请求 path=%q client_ip=%s signature=%q body=%q", c.Request.RequestURI, c.ClientIP(), signature, string(bodyBytes)))
|
||||
if signature == "" {
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Creem webhook 缺少签名 path=%q client_ip=%s body=%q", c.Request.RequestURI, c.ClientIP(), string(bodyBytes)))
|
||||
c.AbortWithStatus(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证签名
|
||||
if !verifyCreemSignature(string(bodyBytes), signature, setting.CreemWebhookSecret) {
|
||||
log.Printf("Creem Webhook签名验证失败")
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Creem webhook 验签失败 path=%q client_ip=%s signature=%q body=%q", c.Request.RequestURI, c.ClientIP(), signature, string(bodyBytes)))
|
||||
c.AbortWithStatus(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Creem Webhook签名验证成功")
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem webhook 验签成功 path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
|
||||
|
||||
// 重新设置body供后续的ShouldBindJSON使用
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
@@ -266,19 +265,19 @@ func CreemWebhook(c *gin.Context) {
|
||||
// 解析新格式的webhook数据
|
||||
var webhookEvent CreemWebhookEvent
|
||||
if err := c.ShouldBindJSON(&webhookEvent); err != nil {
|
||||
log.Printf("解析Creem Webhook参数失败: %v", err)
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem webhook 解析失败 path=%q client_ip=%s error=%q body=%q", c.Request.RequestURI, c.ClientIP(), err.Error(), string(bodyBytes)))
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Creem Webhook解析成功 - EventType: %s, EventId: %s", webhookEvent.EventType, webhookEvent.Id)
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem webhook 解析成功 event_type=%s event_id=%s request_id=%s order_id=%s order_status=%s", webhookEvent.EventType, webhookEvent.Id, webhookEvent.Object.RequestId, webhookEvent.Object.Order.Id, webhookEvent.Object.Order.Status))
|
||||
|
||||
// 根据事件类型处理不同的webhook
|
||||
switch webhookEvent.EventType {
|
||||
case "checkout.completed":
|
||||
handleCheckoutCompleted(c, &webhookEvent)
|
||||
default:
|
||||
log.Printf("忽略Creem Webhook事件类型: %s", webhookEvent.EventType)
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem webhook 忽略事件 event_type=%s event_id=%s", webhookEvent.EventType, webhookEvent.Id))
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
}
|
||||
@@ -287,7 +286,7 @@ func CreemWebhook(c *gin.Context) {
|
||||
func handleCheckoutCompleted(c *gin.Context, event *CreemWebhookEvent) {
|
||||
// 验证订单状态
|
||||
if event.Object.Order.Status != "paid" {
|
||||
log.Printf("订单状态不是已支付: %s, 跳过处理", event.Object.Order.Status)
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 订单状态未支付,忽略处理 request_id=%s order_id=%s order_status=%s", event.Object.RequestId, event.Object.Order.Id, event.Object.Order.Status))
|
||||
c.Status(http.StatusOK)
|
||||
return
|
||||
}
|
||||
@@ -295,7 +294,7 @@ func handleCheckoutCompleted(c *gin.Context, event *CreemWebhookEvent) {
|
||||
// 获取引用ID(这是我们创建订单时传递的request_id)
|
||||
referenceId := event.Object.RequestId
|
||||
if referenceId == "" {
|
||||
log.Println("Creem Webhook缺少request_id字段")
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Creem webhook 缺少 request_id event_id=%s order_id=%s", event.Id, event.Object.Order.Id))
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
@@ -303,40 +302,35 @@ func handleCheckoutCompleted(c *gin.Context, event *CreemWebhookEvent) {
|
||||
// Try complete subscription order first
|
||||
LockOrder(referenceId)
|
||||
defer UnlockOrder(referenceId)
|
||||
if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(event)); err == nil {
|
||||
if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(event), model.PaymentProviderCreem, ""); err == nil {
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 订阅订单处理成功 trade_no=%s creem_order_id=%s", referenceId, event.Object.Order.Id))
|
||||
c.Status(http.StatusOK)
|
||||
return
|
||||
} else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) {
|
||||
log.Printf("Creem订阅订单处理失败: %s, 订单号: %s", err.Error(), referenceId)
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 订阅订单处理失败 trade_no=%s creem_order_id=%s error=%q", referenceId, event.Object.Order.Id, err.Error()))
|
||||
c.AbortWithStatus(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证订单类型,目前只处理一次性付款(充值)
|
||||
if event.Object.Order.Type != "onetime" {
|
||||
log.Printf("暂不支持的订单类型: %s, 跳过处理", event.Object.Order.Type)
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 暂不支持该订单类型,忽略处理 request_id=%s creem_order_id=%s order_type=%s", referenceId, event.Object.Order.Id, event.Object.Order.Type))
|
||||
c.Status(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
// 记录详细的支付信息
|
||||
log.Printf("处理Creem支付完成 - 订单号: %s, Creem订单ID: %s, 支付金额: %d %s, 客户邮箱: <redacted>, 产品: %s",
|
||||
referenceId,
|
||||
event.Object.Order.Id,
|
||||
event.Object.Order.AmountPaid,
|
||||
event.Object.Order.Currency,
|
||||
event.Object.Product.Name)
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 支付完成回调 trade_no=%s creem_order_id=%s amount_paid=%d currency=%s product_name=%q customer_email=%q customer_name=%q", referenceId, event.Object.Order.Id, event.Object.Order.AmountPaid, event.Object.Order.Currency, event.Object.Product.Name, event.Object.Customer.Email, event.Object.Customer.Name))
|
||||
|
||||
// 查询本地订单确认存在
|
||||
topUp := model.GetTopUpByTradeNo(referenceId)
|
||||
if topUp == nil {
|
||||
log.Printf("Creem充值订单不存在: %s", referenceId)
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Creem 充值订单不存在 trade_no=%s creem_order_id=%s", referenceId, event.Object.Order.Id))
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if topUp.Status != common.TopUpStatusPending {
|
||||
log.Printf("Creem充值订单状态错误: %s, 当前状态: %s", referenceId, topUp.Status)
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 充值订单状态非 pending,忽略处理 trade_no=%s status=%s creem_order_id=%s", referenceId, topUp.Status, event.Object.Order.Id))
|
||||
c.Status(http.StatusOK) // 已处理过的订单,返回成功避免重复处理
|
||||
return
|
||||
}
|
||||
@@ -347,21 +341,20 @@ func handleCheckoutCompleted(c *gin.Context, event *CreemWebhookEvent) {
|
||||
|
||||
// 防护性检查,确保邮箱和姓名不为空字符串
|
||||
if customerEmail == "" {
|
||||
log.Printf("警告:Creem回调中客户邮箱为空 - 订单号: %s", referenceId)
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Creem 回调客户邮箱为空 trade_no=%s creem_order_id=%s", referenceId, event.Object.Order.Id))
|
||||
}
|
||||
if customerName == "" {
|
||||
log.Printf("警告:Creem回调中客户姓名为空 - 订单号: %s", referenceId)
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Creem 回调客户姓名为空 trade_no=%s creem_order_id=%s", referenceId, event.Object.Order.Id))
|
||||
}
|
||||
|
||||
err := model.RechargeCreem(referenceId, customerEmail, customerName)
|
||||
err := model.RechargeCreem(referenceId, customerEmail, customerName, c.ClientIP())
|
||||
if err != nil {
|
||||
log.Printf("Creem充值处理失败: %s, 订单号: %s", err.Error(), referenceId)
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 充值处理失败 trade_no=%s creem_order_id=%s client_ip=%s error=%q", referenceId, event.Object.Order.Id, c.ClientIP(), err.Error()))
|
||||
c.AbortWithStatus(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Creem充值成功 - 订单号: %s, 充值额度: %d, 支付金额: %.2f",
|
||||
referenceId, topUp.Amount, topUp.Money)
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 充值成功 trade_no=%s creem_order_id=%s quota=%d money=%.2f client_ip=%s", referenceId, event.Object.Order.Id, topUp.Amount, topUp.Money, c.ClientIP()))
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
@@ -379,7 +372,7 @@ type CreemCheckoutResponse struct {
|
||||
Id string `json:"id"`
|
||||
}
|
||||
|
||||
func genCreemLink(referenceId string, product *CreemProduct, email string, username string) (string, error) {
|
||||
func genCreemLink(ctx context.Context, referenceId string, product *CreemProduct, email string, username string) (string, error) {
|
||||
if setting.CreemApiKey == "" {
|
||||
return "", fmt.Errorf("未配置Creem API密钥")
|
||||
}
|
||||
@@ -388,7 +381,7 @@ func genCreemLink(referenceId string, product *CreemProduct, email string, usern
|
||||
apiUrl := "https://api.creem.io/v1/checkouts"
|
||||
if setting.CreemTestMode {
|
||||
apiUrl = "https://test-api.creem.io/v1/checkouts"
|
||||
log.Printf("使用Creem测试环境: %s", apiUrl)
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Creem 使用测试环境 api_url=%s", apiUrl))
|
||||
}
|
||||
|
||||
// 构建请求数据,确保包含用户邮箱
|
||||
@@ -424,8 +417,7 @@ func genCreemLink(referenceId string, product *CreemProduct, email string, usern
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("x-api-key", setting.CreemApiKey)
|
||||
|
||||
log.Printf("发送Creem支付请求 - URL: %s, 产品ID: %s, 用户邮箱: %s, 订单号: %s",
|
||||
apiUrl, product.ProductId, email, referenceId)
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Creem 支付请求已发送 api_url=%s product_id=%s email=%q trade_no=%s", apiUrl, product.ProductId, email, referenceId))
|
||||
|
||||
// 发送请求
|
||||
client := &http.Client{
|
||||
@@ -443,7 +435,7 @@ func genCreemLink(referenceId string, product *CreemProduct, email string, usern
|
||||
return "", fmt.Errorf("读取响应失败: %v", err)
|
||||
}
|
||||
|
||||
log.Printf("Creem API resp - status code: %d, resp: %s", resp.StatusCode, string(body))
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Creem API 响应已收到 trade_no=%s status_code=%d body=%q", referenceId, resp.StatusCode, string(body)))
|
||||
|
||||
// 检查响应状态
|
||||
if resp.StatusCode/100 != 2 {
|
||||
@@ -460,6 +452,6 @@ func genCreemLink(referenceId string, product *CreemProduct, email string, usern
|
||||
return "", fmt.Errorf("Creem API resp no checkout url ")
|
||||
}
|
||||
|
||||
log.Printf("Creem 支付链接创建成功 - 订单号: %s, 支付链接: %s", referenceId, checkoutResp.CheckoutUrl)
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Creem 支付链接创建成功 trade_no=%s response_id=%s checkout_url=%q", referenceId, checkoutResp.Id, checkoutResp.CheckoutUrl))
|
||||
return checkoutResp.CheckoutUrl, nil
|
||||
}
|
||||
|
||||
+74
-75
@@ -1,16 +1,17 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
@@ -23,10 +24,6 @@ import (
|
||||
"github.com/thanhpk/randstr"
|
||||
)
|
||||
|
||||
const (
|
||||
PaymentMethodStripe = "stripe"
|
||||
)
|
||||
|
||||
var stripeAdaptor = &StripeAdaptor{}
|
||||
|
||||
// StripePayRequest represents a payment request for Stripe checkout.
|
||||
@@ -48,34 +45,34 @@ type StripeAdaptor struct {
|
||||
|
||||
func (*StripeAdaptor) RequestAmount(c *gin.Context, req *StripePayRequest) {
|
||||
if req.Amount < getStripeMinTopup() {
|
||||
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup())})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup())})
|
||||
return
|
||||
}
|
||||
id := c.GetInt("id")
|
||||
group, err := model.GetUserGroup(id, true)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "获取用户分组失败"})
|
||||
return
|
||||
}
|
||||
payMoney := getStripePayMoney(float64(req.Amount), group)
|
||||
if payMoney <= 0.01 {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
|
||||
}
|
||||
|
||||
func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) {
|
||||
if req.PaymentMethod != PaymentMethodStripe {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"})
|
||||
if req.PaymentMethod != model.PaymentMethodStripe {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "不支持的支付渠道"})
|
||||
return
|
||||
}
|
||||
if req.Amount < getStripeMinTopup() {
|
||||
c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup()), "data": 10})
|
||||
c.JSON(http.StatusOK, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup()), "data": 10})
|
||||
return
|
||||
}
|
||||
if req.Amount > 10000 {
|
||||
c.JSON(200, gin.H{"message": "充值数量不能大于 10000", "data": 10})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "充值数量不能大于 10000", "data": 10})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -98,26 +95,29 @@ func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) {
|
||||
|
||||
payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, req.Amount, req.SuccessURL, req.CancelURL)
|
||||
if err != nil {
|
||||
log.Println("获取Stripe Checkout支付链接失败", err)
|
||||
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Stripe 创建 Checkout Session 失败 user_id=%d trade_no=%s amount=%d error=%q", id, referenceId, req.Amount, err.Error()))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
return
|
||||
}
|
||||
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: req.Amount,
|
||||
Money: chargedMoney,
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: PaymentMethodStripe,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: id,
|
||||
Amount: req.Amount,
|
||||
Money: chargedMoney,
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: model.PaymentMethodStripe,
|
||||
PaymentProvider: model.PaymentProviderStripe,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
err = topUp.Insert()
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Stripe 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, referenceId, req.Amount, err.Error()))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Stripe 充值订单创建成功 user_id=%d trade_no=%s amount=%d money=%.2f", id, referenceId, req.Amount, chargedMoney))
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"pay_link": payLink,
|
||||
@@ -129,7 +129,7 @@ func RequestStripeAmount(c *gin.Context) {
|
||||
var req StripePayRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
stripeAdaptor.RequestAmount(c, &req)
|
||||
@@ -139,89 +139,93 @@ func RequestStripePay(c *gin.Context) {
|
||||
var req StripePayRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
stripeAdaptor.RequestPay(c, &req)
|
||||
}
|
||||
|
||||
func StripeWebhook(c *gin.Context) {
|
||||
if setting.StripeWebhookSecret == "" {
|
||||
log.Println("Stripe Webhook Secret 未配置,拒绝处理")
|
||||
ctx := c.Request.Context()
|
||||
if !isStripeWebhookEnabled() {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("Stripe webhook 被拒绝 reason=webhook_disabled path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
|
||||
c.AbortWithStatus(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
payload, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
log.Printf("解析Stripe Webhook参数失败: %v\n", err)
|
||||
logger.LogError(ctx, fmt.Sprintf("Stripe webhook 读取请求体失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
|
||||
c.AbortWithStatus(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
signature := c.GetHeader("Stripe-Signature")
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Stripe webhook 收到请求 path=%q client_ip=%s signature=%q body=%q", c.Request.RequestURI, c.ClientIP(), signature, string(payload)))
|
||||
event, err := webhook.ConstructEventWithOptions(payload, signature, setting.StripeWebhookSecret, webhook.ConstructEventOptions{
|
||||
IgnoreAPIVersionMismatch: true,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Stripe Webhook验签失败: %v\n", err)
|
||||
logger.LogWarn(ctx, fmt.Sprintf("Stripe webhook 验签失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
callerIp := c.ClientIP()
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Stripe webhook 验签成功 event_type=%s client_ip=%s path=%q", string(event.Type), callerIp, c.Request.RequestURI))
|
||||
switch event.Type {
|
||||
case stripe.EventTypeCheckoutSessionCompleted:
|
||||
sessionCompleted(event)
|
||||
sessionCompleted(ctx, event, callerIp)
|
||||
case stripe.EventTypeCheckoutSessionExpired:
|
||||
sessionExpired(event)
|
||||
sessionExpired(ctx, event)
|
||||
case stripe.EventTypeCheckoutSessionAsyncPaymentSucceeded:
|
||||
sessionAsyncPaymentSucceeded(event)
|
||||
sessionAsyncPaymentSucceeded(ctx, event, callerIp)
|
||||
case stripe.EventTypeCheckoutSessionAsyncPaymentFailed:
|
||||
sessionAsyncPaymentFailed(event)
|
||||
sessionAsyncPaymentFailed(ctx, event, callerIp)
|
||||
default:
|
||||
log.Printf("不支持的Stripe Webhook事件类型: %s\n", event.Type)
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Stripe webhook 忽略事件 event_type=%s client_ip=%s", string(event.Type), callerIp))
|
||||
}
|
||||
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
func sessionCompleted(event stripe.Event) {
|
||||
func sessionCompleted(ctx context.Context, event stripe.Event, callerIp string) {
|
||||
customerId := event.GetObjectValue("customer")
|
||||
referenceId := event.GetObjectValue("client_reference_id")
|
||||
status := event.GetObjectValue("status")
|
||||
if "complete" != status {
|
||||
log.Println("错误的Stripe Checkout完成状态:", status, ",", referenceId)
|
||||
logger.LogWarn(ctx, fmt.Sprintf("Stripe checkout.completed 状态异常,忽略处理 trade_no=%s status=%s client_ip=%s", referenceId, status, callerIp))
|
||||
return
|
||||
}
|
||||
|
||||
paymentStatus := event.GetObjectValue("payment_status")
|
||||
if paymentStatus != "paid" {
|
||||
log.Printf("Stripe Checkout 支付尚未完成,payment_status: %s, ref: %s(等待异步支付结果)", paymentStatus, referenceId)
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Stripe Checkout 支付未完成,等待异步结果 trade_no=%s payment_status=%s client_ip=%s", referenceId, paymentStatus, callerIp))
|
||||
return
|
||||
}
|
||||
|
||||
fulfillOrder(event, referenceId, customerId)
|
||||
fulfillOrder(ctx, event, referenceId, customerId, callerIp)
|
||||
}
|
||||
|
||||
// sessionAsyncPaymentSucceeded handles delayed payment methods (bank transfer, SEPA, etc.)
|
||||
// that confirm payment after the checkout session completes.
|
||||
func sessionAsyncPaymentSucceeded(event stripe.Event) {
|
||||
func sessionAsyncPaymentSucceeded(ctx context.Context, event stripe.Event, callerIp string) {
|
||||
customerId := event.GetObjectValue("customer")
|
||||
referenceId := event.GetObjectValue("client_reference_id")
|
||||
log.Printf("Stripe 异步支付成功: %s", referenceId)
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Stripe 异步支付成功 trade_no=%s client_ip=%s", referenceId, callerIp))
|
||||
|
||||
fulfillOrder(event, referenceId, customerId)
|
||||
fulfillOrder(ctx, event, referenceId, customerId, callerIp)
|
||||
}
|
||||
|
||||
// sessionAsyncPaymentFailed marks orders as failed when delayed payment methods
|
||||
// ultimately fail (e.g. bank transfer not received, SEPA rejected).
|
||||
func sessionAsyncPaymentFailed(event stripe.Event) {
|
||||
func sessionAsyncPaymentFailed(ctx context.Context, event stripe.Event, callerIp string) {
|
||||
referenceId := event.GetObjectValue("client_reference_id")
|
||||
log.Printf("Stripe 异步支付失败: %s", referenceId)
|
||||
logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败 trade_no=%s client_ip=%s", referenceId, callerIp))
|
||||
|
||||
if len(referenceId) == 0 {
|
||||
log.Println("异步支付失败事件未提供支付单号")
|
||||
logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败事件缺少订单号 client_ip=%s", callerIp))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -230,32 +234,32 @@ func sessionAsyncPaymentFailed(event stripe.Event) {
|
||||
|
||||
topUp := model.GetTopUpByTradeNo(referenceId)
|
||||
if topUp == nil {
|
||||
log.Println("异步支付失败,充值订单不存在:", referenceId)
|
||||
logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败但本地订单不存在 trade_no=%s client_ip=%s", referenceId, callerIp))
|
||||
return
|
||||
}
|
||||
|
||||
if topUp.PaymentMethod != PaymentMethodStripe {
|
||||
log.Printf("异步支付失败,订单支付方式不匹配: %s, ref: %s", topUp.PaymentMethod, referenceId)
|
||||
if topUp.PaymentProvider != model.PaymentProviderStripe {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败但订单支付网关不匹配 trade_no=%s payment_provider=%s client_ip=%s", referenceId, topUp.PaymentProvider, callerIp))
|
||||
return
|
||||
}
|
||||
|
||||
if topUp.Status != common.TopUpStatusPending {
|
||||
log.Printf("异步支付失败,订单状态非pending: %s, ref: %s", topUp.Status, referenceId)
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Stripe 异步支付失败但订单状态非 pending,忽略处理 trade_no=%s status=%s client_ip=%s", referenceId, topUp.Status, callerIp))
|
||||
return
|
||||
}
|
||||
|
||||
topUp.Status = common.TopUpStatusFailed
|
||||
if err := topUp.Update(); err != nil {
|
||||
log.Printf("标记充值订单失败出错: %v, ref: %s", err, referenceId)
|
||||
logger.LogError(ctx, fmt.Sprintf("Stripe 标记充值订单失败状态失败 trade_no=%s client_ip=%s error=%q", referenceId, callerIp, err.Error()))
|
||||
return
|
||||
}
|
||||
log.Printf("充值订单已标记为失败: %s", referenceId)
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Stripe 充值订单已标记为失败 trade_no=%s client_ip=%s", referenceId, callerIp))
|
||||
}
|
||||
|
||||
// fulfillOrder is the shared logic for crediting quota after payment is confirmed.
|
||||
func fulfillOrder(event stripe.Event, referenceId string, customerId string) {
|
||||
func fulfillOrder(ctx context.Context, event stripe.Event, referenceId string, customerId string, callerIp string) {
|
||||
if len(referenceId) == 0 {
|
||||
log.Println("未提供支付单号")
|
||||
logger.LogWarn(ctx, fmt.Sprintf("Stripe 完成订单时缺少订单号 client_ip=%s", callerIp))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -267,65 +271,60 @@ func fulfillOrder(event stripe.Event, referenceId string, customerId string) {
|
||||
"currency": strings.ToUpper(event.GetObjectValue("currency")),
|
||||
"event_type": string(event.Type),
|
||||
}
|
||||
if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(payload)); err == nil {
|
||||
if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(payload), model.PaymentProviderStripe, ""); err == nil {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Stripe 订阅订单处理成功 trade_no=%s event_type=%s client_ip=%s", referenceId, string(event.Type), callerIp))
|
||||
return
|
||||
} else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) {
|
||||
log.Println("complete subscription order failed:", err.Error(), referenceId)
|
||||
logger.LogError(ctx, fmt.Sprintf("Stripe 订阅订单处理失败 trade_no=%s event_type=%s client_ip=%s error=%q", referenceId, string(event.Type), callerIp, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
err := model.Recharge(referenceId, customerId)
|
||||
err := model.Recharge(referenceId, customerId, callerIp)
|
||||
if err != nil {
|
||||
log.Println(err.Error(), referenceId)
|
||||
logger.LogError(ctx, fmt.Sprintf("Stripe 充值处理失败 trade_no=%s event_type=%s client_ip=%s error=%q", referenceId, string(event.Type), callerIp, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
total, _ := strconv.ParseFloat(event.GetObjectValue("amount_total"), 64)
|
||||
currency := strings.ToUpper(event.GetObjectValue("currency"))
|
||||
log.Printf("收到款项:%s, %.2f(%s)", referenceId, total/100, currency)
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Stripe 充值成功 trade_no=%s amount_total=%.2f currency=%s event_type=%s client_ip=%s", referenceId, total/100, currency, string(event.Type), callerIp))
|
||||
}
|
||||
|
||||
func sessionExpired(event stripe.Event) {
|
||||
func sessionExpired(ctx context.Context, event stripe.Event) {
|
||||
referenceId := event.GetObjectValue("client_reference_id")
|
||||
status := event.GetObjectValue("status")
|
||||
if "expired" != status {
|
||||
log.Println("错误的Stripe Checkout过期状态:", status, ",", referenceId)
|
||||
logger.LogWarn(ctx, fmt.Sprintf("Stripe checkout.expired 状态异常,忽略处理 trade_no=%s status=%s", referenceId, status))
|
||||
return
|
||||
}
|
||||
|
||||
if len(referenceId) == 0 {
|
||||
log.Println("未提供支付单号")
|
||||
logger.LogWarn(ctx, "Stripe checkout.expired 缺少订单号")
|
||||
return
|
||||
}
|
||||
|
||||
// Subscription order expiration
|
||||
LockOrder(referenceId)
|
||||
defer UnlockOrder(referenceId)
|
||||
if err := model.ExpireSubscriptionOrder(referenceId); err == nil {
|
||||
if err := model.ExpireSubscriptionOrder(referenceId, model.PaymentProviderStripe); err == nil {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Stripe 订阅订单已过期 trade_no=%s", referenceId))
|
||||
return
|
||||
} else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) {
|
||||
log.Println("过期订阅订单失败", referenceId, ", err:", err.Error())
|
||||
logger.LogError(ctx, fmt.Sprintf("Stripe 订阅订单过期处理失败 trade_no=%s error=%q", referenceId, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
topUp := model.GetTopUpByTradeNo(referenceId)
|
||||
if topUp == nil {
|
||||
log.Println("充值订单不存在", referenceId)
|
||||
err := model.UpdatePendingTopUpStatus(referenceId, model.PaymentProviderStripe, common.TopUpStatusExpired)
|
||||
if errors.Is(err, model.ErrTopUpNotFound) {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("Stripe 充值订单不存在,无法标记过期 trade_no=%s", referenceId))
|
||||
return
|
||||
}
|
||||
|
||||
if topUp.Status != common.TopUpStatusPending {
|
||||
log.Println("充值订单状态错误", referenceId)
|
||||
}
|
||||
|
||||
topUp.Status = common.TopUpStatusExpired
|
||||
err := topUp.Update()
|
||||
if err != nil {
|
||||
log.Println("过期充值订单失败", referenceId, ", err:", err.Error())
|
||||
logger.LogError(ctx, fmt.Sprintf("Stripe 充值订单过期处理失败 trade_no=%s error=%q", referenceId, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
log.Println("充值订单已过期", referenceId)
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Stripe 充值订单已过期 trade_no=%s", referenceId))
|
||||
}
|
||||
|
||||
// genStripeLink generates a Stripe Checkout session URL for payment.
|
||||
|
||||
+80
-42
@@ -1,14 +1,15 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting"
|
||||
@@ -99,28 +100,57 @@ type WaffoPayRequest struct {
|
||||
PayMethodName string `json:"pay_method_name"` // Deprecated: 兼容旧前端,优先使用 pay_method_index
|
||||
}
|
||||
|
||||
func RequestWaffoAmount(c *gin.Context) {
|
||||
var req WaffoPayRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
|
||||
waffoMinTopup := int64(setting.WaffoMinTopUp)
|
||||
if req.Amount < waffoMinTopup {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", waffoMinTopup)})
|
||||
return
|
||||
}
|
||||
|
||||
id := c.GetInt("id")
|
||||
group, err := model.GetUserGroup(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "获取用户分组失败"})
|
||||
return
|
||||
}
|
||||
|
||||
payMoney := getWaffoPayMoney(float64(req.Amount), group)
|
||||
if payMoney <= 0.01 {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
|
||||
}
|
||||
|
||||
// RequestWaffoPay 创建 Waffo 支付订单
|
||||
func RequestWaffoPay(c *gin.Context) {
|
||||
if !setting.WaffoEnabled {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "Waffo 支付未启用"})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo 支付未启用"})
|
||||
return
|
||||
}
|
||||
|
||||
var req WaffoPayRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
waffoMinTopup := int64(setting.WaffoMinTopUp)
|
||||
if req.Amount < waffoMinTopup {
|
||||
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", waffoMinTopup)})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", waffoMinTopup)})
|
||||
return
|
||||
}
|
||||
|
||||
id := c.GetInt("id")
|
||||
user, err := model.GetUserById(id, false)
|
||||
if err != nil || user == nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "用户不存在"})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "用户不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -131,8 +161,8 @@ func RequestWaffoPay(c *gin.Context) {
|
||||
// 新协议:按索引查找
|
||||
idx := *req.PayMethodIndex
|
||||
if idx < 0 || idx >= len(methods) {
|
||||
log.Printf("Waffo 无效的支付方式索引: %d, UserId=%d, 可用范围: [0, %d)", idx, id, len(methods))
|
||||
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付方式"})
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo 支付方式索引无效 user_id=%d pay_method_index=%d method_count=%d", id, idx, len(methods)))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "不支持的支付方式"})
|
||||
return
|
||||
}
|
||||
resolvedPayMethodType = methods[idx].PayMethodType
|
||||
@@ -149,8 +179,8 @@ func RequestWaffoPay(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
log.Printf("Waffo 无效的支付方式: PayMethodType=%s, PayMethodName=%s, UserId=%d", req.PayMethodType, req.PayMethodName, id)
|
||||
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付方式"})
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo 支付方式无效 user_id=%d pay_method_type=%s pay_method_name=%q", id, req.PayMethodType, req.PayMethodName))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "不支持的支付方式"})
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -159,7 +189,7 @@ func RequestWaffoPay(c *gin.Context) {
|
||||
group, _ := model.GetUserGroup(id, true)
|
||||
payMoney := getWaffoPayMoney(float64(req.Amount), group)
|
||||
if payMoney < 0.01 {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -178,26 +208,27 @@ func RequestWaffoPay(c *gin.Context) {
|
||||
|
||||
// 创建本地订单
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: amount,
|
||||
Money: payMoney,
|
||||
TradeNo: merchantOrderId,
|
||||
PaymentMethod: "waffo",
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: id,
|
||||
Amount: amount,
|
||||
Money: payMoney,
|
||||
TradeNo: merchantOrderId,
|
||||
PaymentMethod: model.PaymentMethodWaffo,
|
||||
PaymentProvider: model.PaymentProviderWaffo,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
if err := topUp.Insert(); err != nil {
|
||||
log.Printf("Waffo 创建本地订单失败: %v", err)
|
||||
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, merchantOrderId, req.Amount, err.Error()))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
return
|
||||
}
|
||||
|
||||
sdk, err := getWaffoSDK()
|
||||
if err != nil {
|
||||
log.Printf("Waffo SDK 初始化失败: %v", err)
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo SDK 初始化失败 user_id=%d trade_no=%s error=%q", id, merchantOrderId, err.Error()))
|
||||
topUp.Status = common.TopUpStatusFailed
|
||||
_ = topUp.Update()
|
||||
c.JSON(200, gin.H{"message": "error", "data": "支付配置错误"})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "支付配置错误"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -238,29 +269,29 @@ func RequestWaffoPay(c *gin.Context) {
|
||||
}
|
||||
resp, err := sdk.Order().Create(c.Request.Context(), createParams, nil)
|
||||
if err != nil {
|
||||
log.Printf("Waffo 创建订单失败: %v", err)
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 创建订单失败 user_id=%d trade_no=%s error=%q", id, merchantOrderId, err.Error()))
|
||||
topUp.Status = common.TopUpStatusFailed
|
||||
_ = topUp.Update()
|
||||
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
return
|
||||
}
|
||||
if !resp.IsSuccess() {
|
||||
log.Printf("Waffo 创建订单业务失败: [%s] %s, 完整响应: %+v", resp.Code, resp.Message, resp)
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo 创建订单业务失败 user_id=%d trade_no=%s code=%s message=%q response=%q", id, merchantOrderId, resp.Code, resp.Message, common.GetJsonString(resp)))
|
||||
topUp.Status = common.TopUpStatusFailed
|
||||
_ = topUp.Update()
|
||||
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
return
|
||||
}
|
||||
|
||||
orderData := resp.GetData()
|
||||
log.Printf("Waffo 订单创建成功 - 用户: %d, 订单: %s, 金额: %.2f", id, merchantOrderId, payMoney)
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo 充值订单创建成功 user_id=%d trade_no=%s amount=%d money=%.2f pay_method_type=%s pay_method_name=%q", id, merchantOrderId, req.Amount, payMoney, resolvedPayMethodType, resolvedPayMethodName))
|
||||
|
||||
paymentUrl := orderData.FetchRedirectURL()
|
||||
if paymentUrl == "" {
|
||||
paymentUrl = orderData.OrderAction
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"payment_url": paymentUrl,
|
||||
@@ -287,16 +318,22 @@ type webhookSubscriptionInfo struct {
|
||||
|
||||
// WaffoWebhook 处理 Waffo 回调通知(支付/退款/订阅)
|
||||
func WaffoWebhook(c *gin.Context) {
|
||||
if !isWaffoWebhookEnabled() {
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo webhook 被拒绝 reason=webhook_disabled path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
|
||||
c.AbortWithStatus(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
log.Printf("Waffo Webhook 读取 body 失败: %v", err)
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo webhook 读取请求体失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
sdk, err := getWaffoSDK()
|
||||
if err != nil {
|
||||
log.Printf("Waffo Webhook SDK 初始化失败: %v", err)
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo webhook SDK 初始化失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
|
||||
c.AbortWithStatus(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -304,17 +341,18 @@ func WaffoWebhook(c *gin.Context) {
|
||||
wh := sdk.Webhook()
|
||||
bodyStr := string(bodyBytes)
|
||||
signature := c.GetHeader("X-SIGNATURE")
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo webhook 收到请求 path=%q client_ip=%s signature=%q body=%q", c.Request.RequestURI, c.ClientIP(), signature, bodyStr))
|
||||
|
||||
// 验证请求签名
|
||||
if !wh.VerifySignature(bodyStr, signature) {
|
||||
log.Printf("Waffo webhook 签名验证失败")
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo webhook 验签失败 path=%q client_ip=%s signature=%q body=%q", c.Request.RequestURI, c.ClientIP(), signature, bodyStr))
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var event core.WebhookEvent
|
||||
if err := common.Unmarshal(bodyBytes, &event); err != nil {
|
||||
log.Printf("Waffo Webhook 解析失败: %v", err)
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo webhook 解析失败 path=%q client_ip=%s error=%q body=%q", c.Request.RequestURI, c.ClientIP(), err.Error(), bodyStr))
|
||||
sendWaffoWebhookResponse(c, wh, false, "invalid payload")
|
||||
return
|
||||
}
|
||||
@@ -324,14 +362,14 @@ func WaffoWebhook(c *gin.Context) {
|
||||
// 解析为扩展类型,区分普通支付和订阅支付
|
||||
var payload webhookPayloadWithSubInfo
|
||||
if err := common.Unmarshal(bodyBytes, &payload); err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 支付回调载荷解析失败 event_type=%s client_ip=%s error=%q body=%q", event.EventType, c.ClientIP(), err.Error(), bodyStr))
|
||||
sendWaffoWebhookResponse(c, wh, false, "invalid payment payload")
|
||||
return
|
||||
}
|
||||
log.Printf("Waffo Webhook - EventType: %s, MerchantOrderId: %s, OrderStatus: %s",
|
||||
event.EventType, payload.Result.MerchantOrderID, payload.Result.OrderStatus)
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo webhook 验签并解析成功 event_type=%s merchant_order_id=%s order_status=%s client_ip=%s", event.EventType, payload.Result.MerchantOrderID, payload.Result.OrderStatus, c.ClientIP()))
|
||||
handleWaffoPayment(c, wh, &payload.Result.PaymentNotificationResult)
|
||||
default:
|
||||
log.Printf("Waffo Webhook 未知事件: %s", event.EventType)
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo webhook 忽略事件 event_type=%s client_ip=%s", event.EventType, c.ClientIP()))
|
||||
sendWaffoWebhookResponse(c, wh, true, "")
|
||||
}
|
||||
}
|
||||
@@ -339,13 +377,13 @@ func WaffoWebhook(c *gin.Context) {
|
||||
// handleWaffoPayment 处理支付完成通知
|
||||
func handleWaffoPayment(c *gin.Context, wh *core.WebhookHandler, result *core.PaymentNotificationResult) {
|
||||
if result.OrderStatus != "PAY_SUCCESS" {
|
||||
log.Printf("Waffo 订单状态非成功: %s, 订单: %s", result.OrderStatus, result.MerchantOrderID)
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo 订单状态非成功,忽略充值 trade_no=%s order_status=%s client_ip=%s", result.MerchantOrderID, result.OrderStatus, c.ClientIP()))
|
||||
// 终态失败订单标记为 failed,避免永远停在 pending
|
||||
if result.MerchantOrderID != "" {
|
||||
if topUp := model.GetTopUpByTradeNo(result.MerchantOrderID); topUp != nil &&
|
||||
topUp.Status == common.TopUpStatusPending {
|
||||
topUp.Status = common.TopUpStatusFailed
|
||||
_ = topUp.Update()
|
||||
if err := model.UpdatePendingTopUpStatus(result.MerchantOrderID, model.PaymentProviderWaffo, common.TopUpStatusFailed); err != nil &&
|
||||
!errors.Is(err, model.ErrTopUpNotFound) &&
|
||||
!errors.Is(err, model.ErrTopUpStatusInvalid) {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 标记失败订单状态失败 trade_no=%s error=%q", result.MerchantOrderID, err.Error()))
|
||||
}
|
||||
}
|
||||
sendWaffoWebhookResponse(c, wh, true, "")
|
||||
@@ -357,13 +395,13 @@ func handleWaffoPayment(c *gin.Context, wh *core.WebhookHandler, result *core.Pa
|
||||
LockOrder(merchantOrderId)
|
||||
defer UnlockOrder(merchantOrderId)
|
||||
|
||||
if err := model.RechargeWaffo(merchantOrderId); err != nil {
|
||||
log.Printf("Waffo 充值处理失败: %v, 订单: %s", err, merchantOrderId)
|
||||
if err := model.RechargeWaffo(merchantOrderId, c.ClientIP()); err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 充值处理失败 trade_no=%s client_ip=%s error=%q", merchantOrderId, c.ClientIP(), err.Error()))
|
||||
sendWaffoWebhookResponse(c, wh, false, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Waffo 充值成功 - 订单: %s", merchantOrderId)
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo 充值成功 trade_no=%s client_ip=%s", merchantOrderId, c.ClientIP()))
|
||||
sendWaffoWebhookResponse(c, wh, true, "")
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,260 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/thanhpk/randstr"
|
||||
)
|
||||
|
||||
type WaffoPancakePayRequest struct {
|
||||
Amount int64 `json:"amount"`
|
||||
}
|
||||
|
||||
func RequestWaffoPancakeAmount(c *gin.Context) {
|
||||
var req WaffoPancakePayRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Amount < int64(setting.WaffoPancakeMinTopUp) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", setting.WaffoPancakeMinTopUp)})
|
||||
return
|
||||
}
|
||||
|
||||
id := c.GetInt("id")
|
||||
group, err := model.GetUserGroup(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "获取用户分组失败"})
|
||||
return
|
||||
}
|
||||
|
||||
payMoney := getWaffoPancakePayMoney(req.Amount, group)
|
||||
if payMoney <= 0.01 {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success", "data": fmt.Sprintf("%.2f", payMoney)})
|
||||
}
|
||||
|
||||
func getWaffoPancakePayMoney(amount int64, group string) float64 {
|
||||
dAmount := decimal.NewFromInt(amount)
|
||||
if operation_setting.GetQuotaDisplayType() == operation_setting.QuotaDisplayTypeTokens {
|
||||
dAmount = dAmount.Div(decimal.NewFromFloat(common.QuotaPerUnit))
|
||||
}
|
||||
|
||||
topupGroupRatio := common.GetTopupGroupRatio(group)
|
||||
if topupGroupRatio == 0 {
|
||||
topupGroupRatio = 1
|
||||
}
|
||||
|
||||
discount := 1.0
|
||||
if ds, ok := operation_setting.GetPaymentSetting().AmountDiscount[int(amount)]; ok && ds > 0 {
|
||||
discount = ds
|
||||
}
|
||||
|
||||
payMoney := dAmount.
|
||||
Mul(decimal.NewFromFloat(setting.WaffoPancakeUnitPrice)).
|
||||
Mul(decimal.NewFromFloat(topupGroupRatio)).
|
||||
Mul(decimal.NewFromFloat(discount))
|
||||
|
||||
return payMoney.InexactFloat64()
|
||||
}
|
||||
|
||||
func normalizeWaffoPancakeTopUpAmount(amount int64) int64 {
|
||||
if operation_setting.GetQuotaDisplayType() != operation_setting.QuotaDisplayTypeTokens {
|
||||
return amount
|
||||
}
|
||||
|
||||
normalized := decimal.NewFromInt(amount).
|
||||
Div(decimal.NewFromFloat(common.QuotaPerUnit)).
|
||||
IntPart()
|
||||
if normalized < 1 {
|
||||
return 1
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func formatWaffoPancakeAmount(payMoney float64) string {
|
||||
return decimal.NewFromFloat(payMoney).StringFixed(2)
|
||||
}
|
||||
|
||||
func getWaffoPancakeBuyerEmail(user *model.User) string {
|
||||
if user != nil && strings.TrimSpace(user.Email) != "" {
|
||||
return user.Email
|
||||
}
|
||||
if user != nil {
|
||||
return fmt.Sprintf("%d@new-api.local", user.Id)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func getWaffoPancakeReturnURL() string {
|
||||
if strings.TrimSpace(setting.WaffoPancakeReturnURL) != "" {
|
||||
return setting.WaffoPancakeReturnURL
|
||||
}
|
||||
return strings.TrimRight(system_setting.ServerAddress, "/") + "/console/topup?show_history=true"
|
||||
}
|
||||
|
||||
func RequestWaffoPancakePay(c *gin.Context) {
|
||||
if !setting.WaffoPancakeEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 支付未启用"})
|
||||
return
|
||||
}
|
||||
currentWebhookKey := setting.WaffoPancakeWebhookPublicKey
|
||||
if setting.WaffoPancakeSandbox {
|
||||
currentWebhookKey = setting.WaffoPancakeWebhookTestKey
|
||||
}
|
||||
if strings.TrimSpace(setting.WaffoPancakeMerchantID) == "" ||
|
||||
strings.TrimSpace(setting.WaffoPancakePrivateKey) == "" ||
|
||||
strings.TrimSpace(currentWebhookKey) == "" ||
|
||||
strings.TrimSpace(setting.WaffoPancakeStoreID) == "" ||
|
||||
strings.TrimSpace(setting.WaffoPancakeProductID) == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 配置不完整"})
|
||||
return
|
||||
}
|
||||
|
||||
var req WaffoPancakePayRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
if req.Amount < int64(setting.WaffoPancakeMinTopUp) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", setting.WaffoPancakeMinTopUp)})
|
||||
return
|
||||
}
|
||||
|
||||
id := c.GetInt("id")
|
||||
user, err := model.GetUserById(id, false)
|
||||
if err != nil || user == nil {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "用户不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
group, err := model.GetUserGroup(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "获取用户分组失败"})
|
||||
return
|
||||
}
|
||||
|
||||
payMoney := getWaffoPancakePayMoney(req.Amount, group)
|
||||
if payMoney < 0.01 {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
}
|
||||
|
||||
tradeNo := fmt.Sprintf("WAFFO_PANCAKE-%d-%d-%s", id, time.Now().UnixMilli(), randstr.String(6))
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: normalizeWaffoPancakeTopUpAmount(req.Amount),
|
||||
Money: payMoney,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: model.PaymentMethodWaffoPancake,
|
||||
PaymentProvider: model.PaymentProviderWaffoPancake,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
if err := topUp.Insert(); err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, tradeNo, req.Amount, err.Error()))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
return
|
||||
}
|
||||
|
||||
expiresInSeconds := 45 * 60
|
||||
session, err := service.CreateWaffoPancakeCheckoutSession(c.Request.Context(), &service.WaffoPancakeCreateSessionParams{
|
||||
StoreID: setting.WaffoPancakeStoreID,
|
||||
ProductID: setting.WaffoPancakeProductID,
|
||||
ProductType: "onetime",
|
||||
Currency: strings.ToUpper(strings.TrimSpace(setting.WaffoPancakeCurrency)),
|
||||
PriceSnapshot: &service.WaffoPancakePriceSnapshot{
|
||||
Amount: formatWaffoPancakeAmount(payMoney),
|
||||
TaxIncluded: false,
|
||||
TaxCategory: "saas",
|
||||
},
|
||||
BuyerEmail: getWaffoPancakeBuyerEmail(user),
|
||||
SuccessURL: getWaffoPancakeReturnURL(),
|
||||
ExpiresInSeconds: &expiresInSeconds,
|
||||
})
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 创建结账会话失败 user_id=%d trade_no=%s error=%q", id, tradeNo, err.Error()))
|
||||
topUp.Status = common.TopUpStatusFailed
|
||||
_ = topUp.Update()
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
return
|
||||
}
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo Pancake 充值订单创建成功 user_id=%d trade_no=%s session_id=%s amount=%d money=%.2f", id, tradeNo, session.SessionID, req.Amount, payMoney))
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"checkout_url": session.CheckoutURL,
|
||||
"session_id": session.SessionID,
|
||||
"expires_at": session.ExpiresAt,
|
||||
"order_id": tradeNo,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func WaffoPancakeWebhook(c *gin.Context) {
|
||||
if !isWaffoPancakeWebhookEnabled() {
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 被拒绝 reason=webhook_disabled path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
|
||||
c.String(http.StatusForbidden, "webhook disabled")
|
||||
return
|
||||
}
|
||||
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 读取请求体失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
|
||||
c.String(http.StatusBadRequest, "bad request")
|
||||
return
|
||||
}
|
||||
|
||||
signature := c.GetHeader("X-Waffo-Signature")
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 收到请求 path=%q client_ip=%s signature=%q body=%q", c.Request.RequestURI, c.ClientIP(), signature, string(bodyBytes)))
|
||||
|
||||
event, err := service.VerifyConfiguredWaffoPancakeWebhook(string(bodyBytes), signature)
|
||||
if err != nil {
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 验签失败 path=%q client_ip=%s signature=%q body=%q error=%q", c.Request.RequestURI, c.ClientIP(), signature, string(bodyBytes), err.Error()))
|
||||
c.String(http.StatusUnauthorized, "invalid signature")
|
||||
return
|
||||
}
|
||||
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 验签成功 event_type=%s event_id=%s order_id=%s client_ip=%s", event.NormalizedEventType(), event.ID, event.Data.OrderID, c.ClientIP()))
|
||||
if event.NormalizedEventType() != "order.completed" {
|
||||
c.String(http.StatusOK, "OK")
|
||||
return
|
||||
}
|
||||
|
||||
tradeNo, err := service.ResolveWaffoPancakeTradeNo(event)
|
||||
if err != nil {
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 订单号映射失败 event_id=%s order_id=%s error=%q", event.ID, event.Data.OrderID, err.Error()))
|
||||
c.String(http.StatusOK, "OK")
|
||||
return
|
||||
}
|
||||
|
||||
LockOrder(tradeNo)
|
||||
defer UnlockOrder(tradeNo)
|
||||
|
||||
if err := model.RechargeWaffoPancake(tradeNo); err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 充值处理失败 trade_no=%s event_id=%s order_id=%s client_ip=%s error=%q", tradeNo, event.ID, event.Data.OrderID, c.ClientIP(), err.Error()))
|
||||
c.String(http.StatusInternalServerError, "retry")
|
||||
return
|
||||
}
|
||||
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo Pancake 充值成功 trade_no=%s event_id=%s order_id=%s client_ip=%s", tradeNo, event.ID, event.Data.OrderID, c.ClientIP()))
|
||||
c.String(http.StatusOK, "OK")
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/setting"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFormatWaffoPancakeAmount_UsesDisplayPriceString(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
amount float64
|
||||
expected string
|
||||
}{
|
||||
{name: "whole amount", amount: 29, expected: "29.00"},
|
||||
{name: "decimal amount", amount: 29.9, expected: "29.90"},
|
||||
{name: "round half up to cents", amount: 29.999, expected: "30.00"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
require.Equal(t, tc.expected, formatWaffoPancakeAmount(tc.amount))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetWaffoPancakePayMoney(t *testing.T) {
|
||||
originalUnitPrice := setting.WaffoPancakeUnitPrice
|
||||
originalQuotaDisplayType := operation_setting.GetGeneralSetting().QuotaDisplayType
|
||||
originalDiscounts := make(map[int]float64, len(operation_setting.GetPaymentSetting().AmountDiscount))
|
||||
for k, v := range operation_setting.GetPaymentSetting().AmountDiscount {
|
||||
originalDiscounts[k] = v
|
||||
}
|
||||
originalTopupGroupRatio := common.TopupGroupRatio2JSONString()
|
||||
|
||||
t.Cleanup(func() {
|
||||
setting.WaffoPancakeUnitPrice = originalUnitPrice
|
||||
operation_setting.GetGeneralSetting().QuotaDisplayType = originalQuotaDisplayType
|
||||
operation_setting.GetPaymentSetting().AmountDiscount = originalDiscounts
|
||||
require.NoError(t, common.UpdateTopupGroupRatioByJSONString(originalTopupGroupRatio))
|
||||
})
|
||||
|
||||
setting.WaffoPancakeUnitPrice = 2.5
|
||||
operation_setting.GetPaymentSetting().AmountDiscount = map[int]float64{
|
||||
10: 0.8,
|
||||
int(common.QuotaPerUnit * 3): 0.5,
|
||||
20: 0,
|
||||
}
|
||||
require.NoError(t, common.UpdateTopupGroupRatioByJSONString(`{"default":1,"vip":1.2}`))
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
amount int64
|
||||
group string
|
||||
quotaDisplayType string
|
||||
expected float64
|
||||
}{
|
||||
{
|
||||
name: "currency display applies unit price group ratio and discount",
|
||||
amount: 10,
|
||||
group: "vip",
|
||||
quotaDisplayType: operation_setting.QuotaDisplayTypeUSD,
|
||||
expected: 24,
|
||||
},
|
||||
{
|
||||
name: "tokens display converts quota to display units before pricing",
|
||||
amount: int64(common.QuotaPerUnit * 3),
|
||||
group: "vip",
|
||||
quotaDisplayType: operation_setting.QuotaDisplayTypeTokens,
|
||||
expected: 4.5,
|
||||
},
|
||||
{
|
||||
name: "non-positive discount falls back to no discount",
|
||||
amount: 20,
|
||||
group: "default",
|
||||
quotaDisplayType: operation_setting.QuotaDisplayTypeUSD,
|
||||
expected: 50,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
operation_setting.GetGeneralSetting().QuotaDisplayType = tc.quotaDisplayType
|
||||
actual := getWaffoPancakePayMoney(tc.amount, tc.group)
|
||||
require.InDelta(t, tc.expected, actual, 0.000001)
|
||||
})
|
||||
}
|
||||
}
|
||||
+8
-4
@@ -2,7 +2,6 @@ package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
@@ -542,10 +541,15 @@ func AdminDisable2FA(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 记录操作日志
|
||||
// 记录操作日志:管理员身份通过 admin_info 传递,避免在非管理员可见的日志内容中泄露。
|
||||
adminId := c.GetInt("id")
|
||||
model.RecordLog(userId, model.LogTypeManage,
|
||||
fmt.Sprintf("管理员(ID:%d)强制禁用了用户的两步验证", adminId))
|
||||
adminName := c.GetString("username")
|
||||
adminInfo := map[string]interface{}{
|
||||
"admin_id": adminId,
|
||||
"admin_username": adminName,
|
||||
}
|
||||
model.RecordLogWithAdminInfo(userId, model.LogTypeManage,
|
||||
"管理员强制禁用了用户的两步验证", adminInfo)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
|
||||
+29
-6
@@ -91,6 +91,7 @@ func Login(c *gin.Context) {
|
||||
|
||||
// setup session & cookies and then return user info
|
||||
func setupLogin(user *model.User, c *gin.Context) {
|
||||
model.UpdateUserLastLoginAt(user.Id)
|
||||
session := sessions.Default(c)
|
||||
session.Set("id", user.Id)
|
||||
session.Set("username", user.Username)
|
||||
@@ -891,6 +892,11 @@ func ManageUser(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
// 删除用户后,强制清理 Redis 中所有该用户令牌的缓存,
|
||||
// 避免已缓存的令牌在 TTL 过期前仍能通过 TokenAuth 校验。
|
||||
if err := model.InvalidateUserTokensCache(user.Id); err != nil {
|
||||
common.SysLog(fmt.Sprintf("failed to invalidate tokens cache for user %d: %s", user.Id, err.Error()))
|
||||
}
|
||||
case "promote":
|
||||
if myRole != common.RoleRootUser {
|
||||
common.ApiErrorI18n(c, i18n.MsgUserAdminCannotPromote)
|
||||
@@ -913,6 +919,11 @@ func ManageUser(c *gin.Context) {
|
||||
user.Role = common.RoleCommonUser
|
||||
case "add_quota":
|
||||
adminName := c.GetString("username")
|
||||
adminId := c.GetInt("id")
|
||||
adminInfo := map[string]interface{}{
|
||||
"admin_id": adminId,
|
||||
"admin_username": adminName,
|
||||
}
|
||||
switch req.Mode {
|
||||
case "add":
|
||||
if req.Value <= 0 {
|
||||
@@ -923,8 +934,8 @@ func ManageUser(c *gin.Context) {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
model.RecordLog(user.Id, model.LogTypeManage,
|
||||
fmt.Sprintf("管理员(%s)增加用户额度 %s", adminName, logger.LogQuota(req.Value)))
|
||||
model.RecordLogWithAdminInfo(user.Id, model.LogTypeManage,
|
||||
fmt.Sprintf("管理员增加用户额度 %s", logger.LogQuota(req.Value)), adminInfo)
|
||||
case "subtract":
|
||||
if req.Value <= 0 {
|
||||
common.ApiErrorI18n(c, i18n.MsgUserQuotaChangeZero)
|
||||
@@ -934,16 +945,16 @@ func ManageUser(c *gin.Context) {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
model.RecordLog(user.Id, model.LogTypeManage,
|
||||
fmt.Sprintf("管理员(%s)减少用户额度 %s", adminName, logger.LogQuota(req.Value)))
|
||||
model.RecordLogWithAdminInfo(user.Id, model.LogTypeManage,
|
||||
fmt.Sprintf("管理员减少用户额度 %s", logger.LogQuota(req.Value)), adminInfo)
|
||||
case "override":
|
||||
oldQuota := user.Quota
|
||||
if err := model.DB.Model(&model.User{}).Where("id = ?", user.Id).Update("quota", req.Value).Error; err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
model.RecordLog(user.Id, model.LogTypeManage,
|
||||
fmt.Sprintf("管理员(%s)覆盖用户额度从 %s 为 %s", adminName, logger.LogQuota(oldQuota), logger.LogQuota(req.Value)))
|
||||
model.RecordLogWithAdminInfo(user.Id, model.LogTypeManage,
|
||||
fmt.Sprintf("管理员覆盖用户额度从 %s 为 %s", logger.LogQuota(oldQuota), logger.LogQuota(req.Value)), adminInfo)
|
||||
default:
|
||||
common.ApiErrorI18n(c, i18n.MsgInvalidParams)
|
||||
return
|
||||
@@ -959,6 +970,18 @@ func ManageUser(c *gin.Context) {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
// 禁用 / 角色调整后,强制失效用户缓存与其全部令牌缓存,
|
||||
// 避免在 Redis TTL 过期前仍使用旧状态(尤其是禁用后仍可发起请求的问题)。
|
||||
// InvalidateUserCache 会让下一次 GetUserCache 从数据库重新加载,
|
||||
// InvalidateUserTokensCache 则确保令牌侧的缓存也同步刷新。
|
||||
if req.Action == "disable" || req.Action == "promote" || req.Action == "demote" {
|
||||
if err := model.InvalidateUserCache(user.Id); err != nil {
|
||||
common.SysLog(fmt.Sprintf("failed to invalidate user cache for user %d: %s", user.Id, err.Error()))
|
||||
}
|
||||
if err := model.InvalidateUserTokensCache(user.Id); err != nil {
|
||||
common.SysLog(fmt.Sprintf("failed to invalidate tokens cache for user %d: %s", user.Id, err.Error()))
|
||||
}
|
||||
}
|
||||
clearUser := model.User{
|
||||
Role: user.Role,
|
||||
Status: user.Status,
|
||||
|
||||
Reference in New Issue
Block a user