fix: add PaymentProvider field to prevent cross-gateway callback attacks

EPay allows users to switch payment methods (e.g. wxpay→alipay) during
checkout, causing callback rejection. Replace fragile blocklist guard
with a PaymentProvider field on TopUp and SubscriptionOrder that
identifies which gateway created the order.
This commit is contained in:
CaIon
2026-04-24 22:16:16 +08:00
parent 8993386743
commit a7c38ec851
12 changed files with 163 additions and 180 deletions
+8 -7
View File
@@ -83,13 +83,14 @@ func SubscriptionRequestCreemPay(c *gin.Context) {
// create pending order first // create pending order first
order := &model.SubscriptionOrder{ order := &model.SubscriptionOrder{
UserId: userId, UserId: userId,
PlanId: plan.Id, PlanId: plan.Id,
Money: plan.PriceAmount, Money: plan.PriceAmount,
TradeNo: referenceId, TradeNo: referenceId,
PaymentMethod: model.PaymentMethodCreem, PaymentMethod: model.PaymentMethodCreem,
CreateTime: time.Now().Unix(), PaymentProvider: model.PaymentProviderCreem,
Status: common.TopUpStatusPending, CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
} }
if err := order.Insert(); err != nil { if err := order.Insert(); err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
+11 -10
View File
@@ -82,13 +82,14 @@ func SubscriptionRequestEpay(c *gin.Context) {
} }
order := &model.SubscriptionOrder{ order := &model.SubscriptionOrder{
UserId: userId, UserId: userId,
PlanId: plan.Id, PlanId: plan.Id,
Money: plan.PriceAmount, Money: plan.PriceAmount,
TradeNo: tradeNo, TradeNo: tradeNo,
PaymentMethod: req.PaymentMethod, PaymentMethod: req.PaymentMethod,
CreateTime: time.Now().Unix(), PaymentProvider: model.PaymentProviderEpay,
Status: common.TopUpStatusPending, CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
} }
if err := order.Insert(); err != nil { if err := order.Insert(); err != nil {
common.ApiErrorMsg(c, "创建订单失败") common.ApiErrorMsg(c, "创建订单失败")
@@ -104,7 +105,7 @@ func SubscriptionRequestEpay(c *gin.Context) {
ReturnUrl: returnUrl, ReturnUrl: returnUrl,
}) })
if err != nil { if err != nil {
_ = model.ExpireSubscriptionOrder(tradeNo, req.PaymentMethod) _ = model.ExpireSubscriptionOrder(tradeNo, model.PaymentProviderEpay)
common.ApiErrorMsg(c, "拉起支付失败") common.ApiErrorMsg(c, "拉起支付失败")
return return
} }
@@ -156,7 +157,7 @@ func SubscriptionEpayNotify(c *gin.Context) {
LockOrder(verifyInfo.ServiceTradeNo) LockOrder(verifyInfo.ServiceTradeNo)
defer UnlockOrder(verifyInfo.ServiceTradeNo) defer UnlockOrder(verifyInfo.ServiceTradeNo)
if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), verifyInfo.Type); err != nil { if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), model.PaymentProviderEpay, verifyInfo.Type); err != nil {
_, _ = c.Writer.Write([]byte("fail")) _, _ = c.Writer.Write([]byte("fail"))
return return
} }
@@ -205,7 +206,7 @@ func SubscriptionEpayReturn(c *gin.Context) {
if verifyInfo.TradeStatus == epay.StatusTradeSuccess { if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
LockOrder(verifyInfo.ServiceTradeNo) LockOrder(verifyInfo.ServiceTradeNo)
defer UnlockOrder(verifyInfo.ServiceTradeNo) defer UnlockOrder(verifyInfo.ServiceTradeNo)
if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), verifyInfo.Type); err != nil { if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), model.PaymentProviderEpay, verifyInfo.Type); err != nil {
c.Redirect(http.StatusFound, system_setting.ServerAddress+"/console/topup?pay=fail") c.Redirect(http.StatusFound, system_setting.ServerAddress+"/console/topup?pay=fail")
return return
} }
+8 -7
View File
@@ -84,13 +84,14 @@ func SubscriptionRequestStripePay(c *gin.Context) {
} }
order := &model.SubscriptionOrder{ order := &model.SubscriptionOrder{
UserId: userId, UserId: userId,
PlanId: plan.Id, PlanId: plan.Id,
Money: plan.PriceAmount, Money: plan.PriceAmount,
TradeNo: referenceId, TradeNo: referenceId,
PaymentMethod: model.PaymentMethodStripe, PaymentMethod: model.PaymentMethodStripe,
CreateTime: time.Now().Unix(), PaymentProvider: model.PaymentProviderStripe,
Status: common.TopUpStatusPending, CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
} }
if err := order.Insert(); err != nil { if err := order.Insert(); err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"}) c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
+14 -24
View File
@@ -123,17 +123,6 @@ type AmountRequest struct {
Amount int64 `json:"amount"` Amount int64 `json:"amount"`
} }
var nonEpayPaymentMethodsForCallback = []string{
model.PaymentMethodStripe,
model.PaymentMethodCreem,
model.PaymentMethodWaffo,
model.PaymentMethodWaffoPancake,
}
func isNonEpayPaymentMethodForEpayCallback(paymentMethod string) bool {
return lo.Contains(nonEpayPaymentMethodsForCallback, paymentMethod)
}
func GetEpayClient() *epay.Client { func GetEpayClient() *epay.Client {
if operation_setting.PayAddress == "" || operation_setting.EpayId == "" || operation_setting.EpayKey == "" { if operation_setting.PayAddress == "" || operation_setting.EpayId == "" || operation_setting.EpayKey == "" {
return nil return nil
@@ -248,13 +237,14 @@ func RequestEpay(c *gin.Context) {
amount = dAmount.Div(dQuotaPerUnit).IntPart() amount = dAmount.Div(dQuotaPerUnit).IntPart()
} }
topUp := &model.TopUp{ topUp := &model.TopUp{
UserId: id, UserId: id,
Amount: amount, Amount: amount,
Money: payMoney, Money: payMoney,
TradeNo: tradeNo, TradeNo: tradeNo,
PaymentMethod: req.PaymentMethod, PaymentMethod: req.PaymentMethod,
CreateTime: time.Now().Unix(), PaymentProvider: model.PaymentProviderEpay,
Status: common.TopUpStatusPending, CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
} }
err = topUp.Insert() err = topUp.Insert()
if err != nil { if err != nil {
@@ -379,15 +369,15 @@ func EpayNotify(c *gin.Context) {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 回调订单不存在 trade_no=%s callback_type=%s client_ip=%s verify_info=%q", verifyInfo.ServiceTradeNo, verifyInfo.Type, c.ClientIP(), common.GetJsonString(verifyInfo))) logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 回调订单不存在 trade_no=%s callback_type=%s client_ip=%s verify_info=%q", verifyInfo.ServiceTradeNo, verifyInfo.Type, c.ClientIP(), common.GetJsonString(verifyInfo)))
return return
} }
if isNonEpayPaymentMethodForEpayCallback(topUp.PaymentMethod) { if topUp.PaymentProvider != model.PaymentProviderEpay {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 订单支付方式不匹配 trade_no=%s order_payment_method=%s callback_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentMethod, verifyInfo.Type, c.ClientIP())) logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 订单支付网关不匹配 trade_no=%s order_provider=%s callback_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentProvider, verifyInfo.Type, c.ClientIP()))
return
}
if topUp.PaymentMethod != verifyInfo.Type {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 订单支付方式不匹配 trade_no=%s order_payment_method=%s callback_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentMethod, verifyInfo.Type, c.ClientIP()))
return return
} }
if topUp.Status == common.TopUpStatusPending { if topUp.Status == common.TopUpStatusPending {
if topUp.PaymentMethod != verifyInfo.Type {
logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 实际支付方式与订单不同 trade_no=%s order_payment_method=%s actual_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentMethod, verifyInfo.Type, c.ClientIP()))
topUp.PaymentMethod = verifyInfo.Type
}
topUp.Status = common.TopUpStatusSuccess topUp.Status = common.TopUpStatusSuccess
err := topUp.Update() err := topUp.Update()
if err != nil { if err != nil {
+9 -8
View File
@@ -106,13 +106,14 @@ func (*CreemAdaptor) RequestPay(c *gin.Context, req *CreemPayRequest) {
// 先创建订单记录,使用产品配置的金额和充值额度 // 先创建订单记录,使用产品配置的金额和充值额度
topUp := &model.TopUp{ topUp := &model.TopUp{
UserId: id, UserId: id,
Amount: selectedProduct.Quota, // 充值额度 Amount: selectedProduct.Quota, // 充值额度
Money: selectedProduct.Price, // 支付金额 Money: selectedProduct.Price, // 支付金额
TradeNo: referenceId, TradeNo: referenceId,
PaymentMethod: model.PaymentMethodCreem, PaymentMethod: model.PaymentMethodCreem,
CreateTime: time.Now().Unix(), PaymentProvider: model.PaymentProviderCreem,
Status: common.TopUpStatusPending, CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
} }
err = topUp.Insert() err = topUp.Insert()
if err != nil { if err != nil {
@@ -301,7 +302,7 @@ func handleCheckoutCompleted(c *gin.Context, event *CreemWebhookEvent) {
// Try complete subscription order first // Try complete subscription order first
LockOrder(referenceId) LockOrder(referenceId)
defer UnlockOrder(referenceId) defer UnlockOrder(referenceId)
if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(event), model.PaymentMethodCreem); err == nil { if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(event), model.PaymentProviderCreem, ""); err == nil {
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 订阅订单处理成功 trade_no=%s creem_order_id=%s", referenceId, event.Object.Order.Id)) logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 订阅订单处理成功 trade_no=%s creem_order_id=%s", referenceId, event.Object.Order.Id))
c.Status(http.StatusOK) c.Status(http.StatusOK)
return return
-31
View File
@@ -1,31 +0,0 @@
package controller
import (
"testing"
"github.com/QuantumNous/new-api/model"
)
func TestIsNonEpayPaymentMethodForEpayCallback(t *testing.T) {
testCases := []struct {
name string
paymentMethod string
expectedBlocked bool
}{
{name: "stripe", paymentMethod: model.PaymentMethodStripe, expectedBlocked: true},
{name: "creem", paymentMethod: model.PaymentMethodCreem, expectedBlocked: true},
{name: "waffo", paymentMethod: model.PaymentMethodWaffo, expectedBlocked: true},
{name: "waffo pancake", paymentMethod: model.PaymentMethodWaffoPancake, expectedBlocked: true},
{name: "alipay", paymentMethod: "alipay", expectedBlocked: false},
{name: "wxpay", paymentMethod: "wxpay", expectedBlocked: false},
{name: "custom epay type", paymentMethod: "custom1", expectedBlocked: false},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if actual := isNonEpayPaymentMethodForEpayCallback(tc.paymentMethod); actual != tc.expectedBlocked {
t.Fatalf("expected blocked=%v, got %v for payment method %q", tc.expectedBlocked, actual, tc.paymentMethod)
}
})
}
}
+13 -12
View File
@@ -101,13 +101,14 @@ func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) {
} }
topUp := &model.TopUp{ topUp := &model.TopUp{
UserId: id, UserId: id,
Amount: req.Amount, Amount: req.Amount,
Money: chargedMoney, Money: chargedMoney,
TradeNo: referenceId, TradeNo: referenceId,
PaymentMethod: model.PaymentMethodStripe, PaymentMethod: model.PaymentMethodStripe,
CreateTime: time.Now().Unix(), PaymentProvider: model.PaymentProviderStripe,
Status: common.TopUpStatusPending, CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
} }
err = topUp.Insert() err = topUp.Insert()
if err != nil { if err != nil {
@@ -237,8 +238,8 @@ func sessionAsyncPaymentFailed(ctx context.Context, event stripe.Event, callerIp
return return
} }
if topUp.PaymentMethod != model.PaymentMethodStripe { if topUp.PaymentProvider != model.PaymentProviderStripe {
logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败但订单支付方式不匹配 trade_no=%s payment_method=%s client_ip=%s", referenceId, topUp.PaymentMethod, callerIp)) logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败但订单支付网关不匹配 trade_no=%s payment_provider=%s client_ip=%s", referenceId, topUp.PaymentProvider, callerIp))
return return
} }
@@ -270,7 +271,7 @@ func fulfillOrder(ctx context.Context, event stripe.Event, referenceId string, c
"currency": strings.ToUpper(event.GetObjectValue("currency")), "currency": strings.ToUpper(event.GetObjectValue("currency")),
"event_type": string(event.Type), "event_type": string(event.Type),
} }
if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(payload), model.PaymentMethodStripe); err == nil { if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(payload), model.PaymentProviderStripe, ""); err == nil {
logger.LogInfo(ctx, fmt.Sprintf("Stripe 订阅订单处理成功 trade_no=%s event_type=%s client_ip=%s", referenceId, string(event.Type), callerIp)) logger.LogInfo(ctx, fmt.Sprintf("Stripe 订阅订单处理成功 trade_no=%s event_type=%s client_ip=%s", referenceId, string(event.Type), callerIp))
return return
} else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) { } else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) {
@@ -305,7 +306,7 @@ func sessionExpired(ctx context.Context, event stripe.Event) {
// Subscription order expiration // Subscription order expiration
LockOrder(referenceId) LockOrder(referenceId)
defer UnlockOrder(referenceId) defer UnlockOrder(referenceId)
if err := model.ExpireSubscriptionOrder(referenceId, model.PaymentMethodStripe); err == nil { if err := model.ExpireSubscriptionOrder(referenceId, model.PaymentProviderStripe); err == nil {
logger.LogInfo(ctx, fmt.Sprintf("Stripe 订阅订单已过期 trade_no=%s", referenceId)) logger.LogInfo(ctx, fmt.Sprintf("Stripe 订阅订单已过期 trade_no=%s", referenceId))
return return
} else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) { } else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) {
@@ -313,7 +314,7 @@ func sessionExpired(ctx context.Context, event stripe.Event) {
return return
} }
err := model.UpdatePendingTopUpStatus(referenceId, model.PaymentMethodStripe, common.TopUpStatusExpired) err := model.UpdatePendingTopUpStatus(referenceId, model.PaymentProviderStripe, common.TopUpStatusExpired)
if errors.Is(err, model.ErrTopUpNotFound) { if errors.Is(err, model.ErrTopUpNotFound) {
logger.LogWarn(ctx, fmt.Sprintf("Stripe 充值订单不存在,无法标记过期 trade_no=%s", referenceId)) logger.LogWarn(ctx, fmt.Sprintf("Stripe 充值订单不存在,无法标记过期 trade_no=%s", referenceId))
return return
+9 -8
View File
@@ -208,13 +208,14 @@ func RequestWaffoPay(c *gin.Context) {
// 创建本地订单 // 创建本地订单
topUp := &model.TopUp{ topUp := &model.TopUp{
UserId: id, UserId: id,
Amount: amount, Amount: amount,
Money: payMoney, Money: payMoney,
TradeNo: merchantOrderId, TradeNo: merchantOrderId,
PaymentMethod: model.PaymentMethodWaffo, PaymentMethod: model.PaymentMethodWaffo,
CreateTime: time.Now().Unix(), PaymentProvider: model.PaymentProviderWaffo,
Status: common.TopUpStatusPending, CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
} }
if err := topUp.Insert(); err != nil { if err := topUp.Insert(); err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, merchantOrderId, req.Amount, err.Error())) logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, merchantOrderId, req.Amount, err.Error()))
@@ -379,7 +380,7 @@ func handleWaffoPayment(c *gin.Context, wh *core.WebhookHandler, result *core.Pa
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo 订单状态非成功,忽略充值 trade_no=%s order_status=%s client_ip=%s", result.MerchantOrderID, result.OrderStatus, c.ClientIP())) logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo 订单状态非成功,忽略充值 trade_no=%s order_status=%s client_ip=%s", result.MerchantOrderID, result.OrderStatus, c.ClientIP()))
// 终态失败订单标记为 failed,避免永远停在 pending // 终态失败订单标记为 failed,避免永远停在 pending
if result.MerchantOrderID != "" { if result.MerchantOrderID != "" {
if err := model.UpdatePendingTopUpStatus(result.MerchantOrderID, model.PaymentMethodWaffo, common.TopUpStatusFailed); err != nil && if err := model.UpdatePendingTopUpStatus(result.MerchantOrderID, model.PaymentProviderWaffo, common.TopUpStatusFailed); err != nil &&
!errors.Is(err, model.ErrTopUpNotFound) && !errors.Is(err, model.ErrTopUpNotFound) &&
!errors.Is(err, model.ErrTopUpStatusInvalid) { !errors.Is(err, model.ErrTopUpStatusInvalid) {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 标记失败订单状态失败 trade_no=%s error=%q", result.MerchantOrderID, err.Error())) logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 标记失败订单状态失败 trade_no=%s error=%q", result.MerchantOrderID, err.Error()))
+8 -7
View File
@@ -159,13 +159,14 @@ func RequestWaffoPancakePay(c *gin.Context) {
tradeNo := fmt.Sprintf("WAFFO_PANCAKE-%d-%d-%s", id, time.Now().UnixMilli(), randstr.String(6)) tradeNo := fmt.Sprintf("WAFFO_PANCAKE-%d-%d-%s", id, time.Now().UnixMilli(), randstr.String(6))
topUp := &model.TopUp{ topUp := &model.TopUp{
UserId: id, UserId: id,
Amount: normalizeWaffoPancakeTopUpAmount(req.Amount), Amount: normalizeWaffoPancakeTopUpAmount(req.Amount),
Money: payMoney, Money: payMoney,
TradeNo: tradeNo, TradeNo: tradeNo,
PaymentMethod: model.PaymentMethodWaffoPancake, PaymentMethod: model.PaymentMethodWaffoPancake,
CreateTime: time.Now().Unix(), PaymentProvider: model.PaymentProviderWaffoPancake,
Status: common.TopUpStatusPending, CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
} }
if err := topUp.Insert(); err != nil { if err := topUp.Insert(); err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, tradeNo, req.Amount, err.Error())) logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, tradeNo, req.Amount, err.Error()))
+43 -41
View File
@@ -36,30 +36,32 @@ func insertSubscriptionPlanForPaymentGuardTest(t *testing.T, id int) *Subscripti
return plan return plan
} }
func insertSubscriptionOrderForPaymentGuardTest(t *testing.T, tradeNo string, userID int, planID int, paymentMethod string) { func insertSubscriptionOrderForPaymentGuardTest(t *testing.T, tradeNo string, userID int, planID int, paymentProvider string) {
t.Helper() t.Helper()
order := &SubscriptionOrder{ order := &SubscriptionOrder{
UserId: userID, UserId: userID,
PlanId: planID, PlanId: planID,
Money: 9.99, Money: 9.99,
TradeNo: tradeNo, TradeNo: tradeNo,
PaymentMethod: paymentMethod, PaymentMethod: paymentProvider,
Status: common.TopUpStatusPending, PaymentProvider: paymentProvider,
CreateTime: time.Now().Unix(), Status: common.TopUpStatusPending,
CreateTime: time.Now().Unix(),
} }
require.NoError(t, order.Insert()) require.NoError(t, order.Insert())
} }
func insertTopUpForPaymentGuardTest(t *testing.T, tradeNo string, userID int, paymentMethod string) { func insertTopUpForPaymentGuardTest(t *testing.T, tradeNo string, userID int, paymentProvider string) {
t.Helper() t.Helper()
topUp := &TopUp{ topUp := &TopUp{
UserId: userID, UserId: userID,
Amount: 2, Amount: 2,
Money: 9.99, Money: 9.99,
TradeNo: tradeNo, TradeNo: tradeNo,
PaymentMethod: paymentMethod, PaymentMethod: paymentProvider,
Status: common.TopUpStatusPending, PaymentProvider: paymentProvider,
CreateTime: time.Now().Unix(), Status: common.TopUpStatusPending,
CreateTime: time.Now().Unix(),
} }
require.NoError(t, topUp.Insert()) require.NoError(t, topUp.Insert())
} }
@@ -89,7 +91,7 @@ func TestRechargeWaffoPancake_RejectsMismatchedPaymentMethod(t *testing.T) {
truncateTables(t) truncateTables(t)
insertUserForPaymentGuardTest(t, 101, 0) insertUserForPaymentGuardTest(t, 101, 0)
insertTopUpForPaymentGuardTest(t, "waffo-pancake-guard", 101, PaymentMethodStripe) insertTopUpForPaymentGuardTest(t, "waffo-pancake-guard", 101, PaymentProviderStripe)
err := RechargeWaffoPancake("waffo-pancake-guard") err := RechargeWaffoPancake("waffo-pancake-guard")
require.Error(t, err) require.Error(t, err)
@@ -100,27 +102,27 @@ func TestRechargeWaffoPancake_RejectsMismatchedPaymentMethod(t *testing.T) {
assert.Equal(t, 0, getUserQuotaForPaymentGuardTest(t, 101)) assert.Equal(t, 0, getUserQuotaForPaymentGuardTest(t, 101))
} }
func TestUpdatePendingTopUpStatus_RejectsMismatchedPaymentMethod(t *testing.T) { func TestUpdatePendingTopUpStatus_RejectsMismatchedPaymentProvider(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
tradeNo string tradeNo string
storedPaymentMethod string storedPaymentProvider string
expectedPaymentMethod string expectedPaymentProvider string
targetStatus string targetStatus string
}{ }{
{ {
name: "stripe expire", name: "stripe expire",
tradeNo: "stripe-expire-guard", tradeNo: "stripe-expire-guard",
storedPaymentMethod: PaymentMethodCreem, storedPaymentProvider: PaymentProviderCreem,
expectedPaymentMethod: PaymentMethodStripe, expectedPaymentProvider: PaymentProviderStripe,
targetStatus: common.TopUpStatusExpired, targetStatus: common.TopUpStatusExpired,
}, },
{ {
name: "waffo failed", name: "waffo failed",
tradeNo: "waffo-failed-guard", tradeNo: "waffo-failed-guard",
storedPaymentMethod: PaymentMethodStripe, storedPaymentProvider: PaymentProviderStripe,
expectedPaymentMethod: PaymentMethodWaffo, expectedPaymentProvider: PaymentProviderWaffo,
targetStatus: common.TopUpStatusFailed, targetStatus: common.TopUpStatusFailed,
}, },
} }
@@ -128,23 +130,23 @@ func TestUpdatePendingTopUpStatus_RejectsMismatchedPaymentMethod(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
truncateTables(t) truncateTables(t)
insertUserForPaymentGuardTest(t, 150, 0) insertUserForPaymentGuardTest(t, 150, 0)
insertTopUpForPaymentGuardTest(t, tc.tradeNo, 150, tc.storedPaymentMethod) insertTopUpForPaymentGuardTest(t, tc.tradeNo, 150, tc.storedPaymentProvider)
err := UpdatePendingTopUpStatus(tc.tradeNo, tc.expectedPaymentMethod, tc.targetStatus) err := UpdatePendingTopUpStatus(tc.tradeNo, tc.expectedPaymentProvider, tc.targetStatus)
require.ErrorIs(t, err, ErrPaymentMethodMismatch) require.ErrorIs(t, err, ErrPaymentMethodMismatch)
assert.Equal(t, common.TopUpStatusPending, getTopUpStatusForPaymentGuardTest(t, tc.tradeNo)) assert.Equal(t, common.TopUpStatusPending, getTopUpStatusForPaymentGuardTest(t, tc.tradeNo))
}) })
} }
} }
func TestCompleteSubscriptionOrder_RejectsMismatchedPaymentMethod(t *testing.T) { func TestCompleteSubscriptionOrder_RejectsMismatchedPaymentProvider(t *testing.T) {
truncateTables(t) truncateTables(t)
insertUserForPaymentGuardTest(t, 202, 0) insertUserForPaymentGuardTest(t, 202, 0)
plan := insertSubscriptionPlanForPaymentGuardTest(t, 301) plan := insertSubscriptionPlanForPaymentGuardTest(t, 301)
insertSubscriptionOrderForPaymentGuardTest(t, "sub-guard-order", 202, plan.Id, PaymentMethodStripe) insertSubscriptionOrderForPaymentGuardTest(t, "sub-guard-order", 202, plan.Id, PaymentProviderStripe)
err := CompleteSubscriptionOrder("sub-guard-order", `{"provider":"epay"}`, "alipay") err := CompleteSubscriptionOrder("sub-guard-order", `{"provider":"epay"}`, PaymentProviderEpay, "alipay")
require.ErrorIs(t, err, ErrPaymentMethodMismatch) require.ErrorIs(t, err, ErrPaymentMethodMismatch)
order := GetSubscriptionOrderByTradeNo("sub-guard-order") order := GetSubscriptionOrderByTradeNo("sub-guard-order")
@@ -156,14 +158,14 @@ func TestCompleteSubscriptionOrder_RejectsMismatchedPaymentMethod(t *testing.T)
assert.Nil(t, topUp) assert.Nil(t, topUp)
} }
func TestExpireSubscriptionOrder_RejectsMismatchedPaymentMethod(t *testing.T) { func TestExpireSubscriptionOrder_RejectsMismatchedPaymentProvider(t *testing.T) {
truncateTables(t) truncateTables(t)
insertUserForPaymentGuardTest(t, 303, 0) insertUserForPaymentGuardTest(t, 303, 0)
plan := insertSubscriptionPlanForPaymentGuardTest(t, 401) plan := insertSubscriptionPlanForPaymentGuardTest(t, 401)
insertSubscriptionOrderForPaymentGuardTest(t, "sub-expire-guard", 303, plan.Id, PaymentMethodStripe) insertSubscriptionOrderForPaymentGuardTest(t, "sub-expire-guard", 303, plan.Id, PaymentProviderStripe)
err := ExpireSubscriptionOrder("sub-expire-guard", PaymentMethodCreem) err := ExpireSubscriptionOrder("sub-expire-guard", PaymentProviderCreem)
require.ErrorIs(t, err, ErrPaymentMethodMismatch) require.ErrorIs(t, err, ErrPaymentMethodMismatch)
order := GetSubscriptionOrderByTradeNo("sub-expire-guard") order := GetSubscriptionOrderByTradeNo("sub-expire-guard")
+15 -9
View File
@@ -198,11 +198,12 @@ type SubscriptionOrder struct {
PlanId int `json:"plan_id" gorm:"index"` PlanId int `json:"plan_id" gorm:"index"`
Money float64 `json:"money"` Money float64 `json:"money"`
TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"` TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"` PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"`
Status string `json:"status"` PaymentProvider string `json:"payment_provider" gorm:"type:varchar(50);default:''"`
CreateTime int64 `json:"create_time"` Status string `json:"status"`
CompleteTime int64 `json:"complete_time"` CreateTime int64 `json:"create_time"`
CompleteTime int64 `json:"complete_time"`
ProviderPayload string `json:"provider_payload" gorm:"type:text"` ProviderPayload string `json:"provider_payload" gorm:"type:text"`
} }
@@ -505,7 +506,9 @@ func CreateUserSubscriptionFromPlanTx(tx *gorm.DB, userId int, plan *Subscriptio
} }
// Complete a subscription order (idempotent). Creates a UserSubscription snapshot from the plan. // Complete a subscription order (idempotent). Creates a UserSubscription snapshot from the plan.
func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedPaymentMethod string) error { // expectedPaymentProvider guards against cross-gateway callback attacks (empty skips the check).
// actualPaymentMethod updates the order's PaymentMethod to reflect the real payment type used (empty skips update).
func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedPaymentProvider string, actualPaymentMethod string) error {
if tradeNo == "" { if tradeNo == "" {
return errors.New("tradeNo is empty") return errors.New("tradeNo is empty")
} }
@@ -523,7 +526,7 @@ func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedP
if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil { if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil {
return ErrSubscriptionOrderNotFound return ErrSubscriptionOrderNotFound
} }
if expectedPaymentMethod != "" && order.PaymentMethod != expectedPaymentMethod { if expectedPaymentProvider != "" && order.PaymentProvider != expectedPaymentProvider {
return ErrPaymentMethodMismatch return ErrPaymentMethodMismatch
} }
if order.Status == common.TopUpStatusSuccess { if order.Status == common.TopUpStatusSuccess {
@@ -552,6 +555,9 @@ func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedP
if providerPayload != "" { if providerPayload != "" {
order.ProviderPayload = providerPayload order.ProviderPayload = providerPayload
} }
if actualPaymentMethod != "" && order.PaymentMethod != actualPaymentMethod {
order.PaymentMethod = actualPaymentMethod
}
if err := tx.Save(&order).Error; err != nil { if err := tx.Save(&order).Error; err != nil {
return err return err
} }
@@ -610,7 +616,7 @@ func upsertSubscriptionTopUpTx(tx *gorm.DB, order *SubscriptionOrder) error {
return tx.Save(&topup).Error return tx.Save(&topup).Error
} }
func ExpireSubscriptionOrder(tradeNo string, expectedPaymentMethod string) error { func ExpireSubscriptionOrder(tradeNo string, expectedPaymentProvider string) error {
if tradeNo == "" { if tradeNo == "" {
return errors.New("tradeNo is empty") return errors.New("tradeNo is empty")
} }
@@ -623,7 +629,7 @@ func ExpireSubscriptionOrder(tradeNo string, expectedPaymentMethod string) error
if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil { if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil {
return ErrSubscriptionOrderNotFound return ErrSubscriptionOrderNotFound
} }
if expectedPaymentMethod != "" && order.PaymentMethod != expectedPaymentMethod { if expectedPaymentProvider != "" && order.PaymentProvider != expectedPaymentProvider {
return ErrPaymentMethodMismatch return ErrPaymentMethodMismatch
} }
if order.Status != common.TopUpStatusPending { if order.Status != common.TopUpStatusPending {
+25 -16
View File
@@ -12,15 +12,16 @@ import (
) )
type TopUp struct { type TopUp struct {
Id int `json:"id"` Id int `json:"id"`
UserId int `json:"user_id" gorm:"index"` UserId int `json:"user_id" gorm:"index"`
Amount int64 `json:"amount"` Amount int64 `json:"amount"`
Money float64 `json:"money"` Money float64 `json:"money"`
TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"` TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"` PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"`
CreateTime int64 `json:"create_time"` PaymentProvider string `json:"payment_provider" gorm:"type:varchar(50);default:''"`
CompleteTime int64 `json:"complete_time"` CreateTime int64 `json:"create_time"`
Status string `json:"status"` CompleteTime int64 `json:"complete_time"`
Status string `json:"status"`
} }
const ( const (
@@ -30,6 +31,14 @@ const (
PaymentMethodWaffoPancake = "waffo_pancake" PaymentMethodWaffoPancake = "waffo_pancake"
) )
const (
PaymentProviderEpay = "epay"
PaymentProviderStripe = "stripe"
PaymentProviderCreem = "creem"
PaymentProviderWaffo = "waffo"
PaymentProviderWaffoPancake = "waffo_pancake"
)
var ( var (
ErrPaymentMethodMismatch = errors.New("payment method mismatch") ErrPaymentMethodMismatch = errors.New("payment method mismatch")
ErrTopUpNotFound = errors.New("topup not found") ErrTopUpNotFound = errors.New("topup not found")
@@ -68,7 +77,7 @@ func GetTopUpByTradeNo(tradeNo string) *TopUp {
return topUp return topUp
} }
func UpdatePendingTopUpStatus(tradeNo string, expectedPaymentMethod string, targetStatus string) error { func UpdatePendingTopUpStatus(tradeNo string, expectedPaymentProvider string, targetStatus string) error {
if tradeNo == "" { if tradeNo == "" {
return errors.New("未提供支付单号") return errors.New("未提供支付单号")
} }
@@ -83,7 +92,7 @@ func UpdatePendingTopUpStatus(tradeNo string, expectedPaymentMethod string, targ
if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(topUp).Error; err != nil { if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(topUp).Error; err != nil {
return ErrTopUpNotFound return ErrTopUpNotFound
} }
if expectedPaymentMethod != "" && topUp.PaymentMethod != expectedPaymentMethod { if expectedPaymentProvider != "" && topUp.PaymentProvider != expectedPaymentProvider {
return ErrPaymentMethodMismatch return ErrPaymentMethodMismatch
} }
if topUp.Status != common.TopUpStatusPending { if topUp.Status != common.TopUpStatusPending {
@@ -114,7 +123,7 @@ func Recharge(referenceId string, customerId string, callerIp string) (err error
return errors.New("充值订单不存在") return errors.New("充值订单不存在")
} }
if topUp.PaymentMethod != PaymentMethodStripe { if topUp.PaymentProvider != PaymentProviderStripe {
return ErrPaymentMethodMismatch return ErrPaymentMethodMismatch
} }
@@ -340,7 +349,7 @@ func ManualCompleteTopUp(tradeNo string, callerIp string) error {
// 计算应充值额度: // 计算应充值额度:
// - Stripe 订单:Money 代表经分组倍率换算后的美元数量,直接 * QuotaPerUnit // - Stripe 订单:Money 代表经分组倍率换算后的美元数量,直接 * QuotaPerUnit
// - 其他订单(如易支付):Amount 为美元数量,* QuotaPerUnit // - 其他订单(如易支付):Amount 为美元数量,* QuotaPerUnit
if topUp.PaymentMethod == PaymentMethodStripe { if topUp.PaymentProvider == PaymentProviderStripe {
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit) dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
quotaToAdd = int(decimal.NewFromFloat(topUp.Money).Mul(dQuotaPerUnit).IntPart()) quotaToAdd = int(decimal.NewFromFloat(topUp.Money).Mul(dQuotaPerUnit).IntPart())
} else { } else {
@@ -397,7 +406,7 @@ func RechargeCreem(referenceId string, customerEmail string, customerName string
return errors.New("充值订单不存在") return errors.New("充值订单不存在")
} }
if topUp.PaymentMethod != PaymentMethodCreem { if topUp.PaymentProvider != PaymentProviderCreem {
return ErrPaymentMethodMismatch return ErrPaymentMethodMismatch
} }
@@ -472,7 +481,7 @@ func RechargeWaffo(tradeNo string, callerIp string) (err error) {
return errors.New("充值订单不存在") return errors.New("充值订单不存在")
} }
if topUp.PaymentMethod != PaymentMethodWaffo { if topUp.PaymentProvider != PaymentProviderWaffo {
return ErrPaymentMethodMismatch return ErrPaymentMethodMismatch
} }
@@ -535,7 +544,7 @@ func RechargeWaffoPancake(tradeNo string) (err error) {
return errors.New("充值订单不存在") return errors.New("充值订单不存在")
} }
if topUp.PaymentMethod != PaymentMethodWaffoPancake { if topUp.PaymentProvider != PaymentProviderWaffoPancake {
return ErrPaymentMethodMismatch return ErrPaymentMethodMismatch
} }