fix: add PaymentProvider field to prevent cross-gateway callback attacks
EPay allows users to switch payment methods (e.g. wxpay→alipay) during checkout, causing callback rejection. Replace fragile blocklist guard with a PaymentProvider field on TopUp and SubscriptionOrder that identifies which gateway created the order.
This commit is contained in:
@@ -83,13 +83,14 @@ func SubscriptionRequestCreemPay(c *gin.Context) {
|
||||
|
||||
// create pending order first
|
||||
order := &model.SubscriptionOrder{
|
||||
UserId: userId,
|
||||
PlanId: plan.Id,
|
||||
Money: plan.PriceAmount,
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: model.PaymentMethodCreem,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: userId,
|
||||
PlanId: plan.Id,
|
||||
Money: plan.PriceAmount,
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: model.PaymentMethodCreem,
|
||||
PaymentProvider: model.PaymentProviderCreem,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
if err := order.Insert(); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
|
||||
@@ -82,13 +82,14 @@ func SubscriptionRequestEpay(c *gin.Context) {
|
||||
}
|
||||
|
||||
order := &model.SubscriptionOrder{
|
||||
UserId: userId,
|
||||
PlanId: plan.Id,
|
||||
Money: plan.PriceAmount,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: req.PaymentMethod,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: userId,
|
||||
PlanId: plan.Id,
|
||||
Money: plan.PriceAmount,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: req.PaymentMethod,
|
||||
PaymentProvider: model.PaymentProviderEpay,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
if err := order.Insert(); err != nil {
|
||||
common.ApiErrorMsg(c, "创建订单失败")
|
||||
@@ -104,7 +105,7 @@ func SubscriptionRequestEpay(c *gin.Context) {
|
||||
ReturnUrl: returnUrl,
|
||||
})
|
||||
if err != nil {
|
||||
_ = model.ExpireSubscriptionOrder(tradeNo, req.PaymentMethod)
|
||||
_ = model.ExpireSubscriptionOrder(tradeNo, model.PaymentProviderEpay)
|
||||
common.ApiErrorMsg(c, "拉起支付失败")
|
||||
return
|
||||
}
|
||||
@@ -156,7 +157,7 @@ func SubscriptionEpayNotify(c *gin.Context) {
|
||||
LockOrder(verifyInfo.ServiceTradeNo)
|
||||
defer UnlockOrder(verifyInfo.ServiceTradeNo)
|
||||
|
||||
if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), verifyInfo.Type); err != nil {
|
||||
if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), model.PaymentProviderEpay, verifyInfo.Type); err != nil {
|
||||
_, _ = c.Writer.Write([]byte("fail"))
|
||||
return
|
||||
}
|
||||
@@ -205,7 +206,7 @@ func SubscriptionEpayReturn(c *gin.Context) {
|
||||
if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
|
||||
LockOrder(verifyInfo.ServiceTradeNo)
|
||||
defer UnlockOrder(verifyInfo.ServiceTradeNo)
|
||||
if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), verifyInfo.Type); err != nil {
|
||||
if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), model.PaymentProviderEpay, verifyInfo.Type); err != nil {
|
||||
c.Redirect(http.StatusFound, system_setting.ServerAddress+"/console/topup?pay=fail")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -84,13 +84,14 @@ func SubscriptionRequestStripePay(c *gin.Context) {
|
||||
}
|
||||
|
||||
order := &model.SubscriptionOrder{
|
||||
UserId: userId,
|
||||
PlanId: plan.Id,
|
||||
Money: plan.PriceAmount,
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: model.PaymentMethodStripe,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: userId,
|
||||
PlanId: plan.Id,
|
||||
Money: plan.PriceAmount,
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: model.PaymentMethodStripe,
|
||||
PaymentProvider: model.PaymentProviderStripe,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
if err := order.Insert(); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
|
||||
+14
-24
@@ -123,17 +123,6 @@ type AmountRequest struct {
|
||||
Amount int64 `json:"amount"`
|
||||
}
|
||||
|
||||
var nonEpayPaymentMethodsForCallback = []string{
|
||||
model.PaymentMethodStripe,
|
||||
model.PaymentMethodCreem,
|
||||
model.PaymentMethodWaffo,
|
||||
model.PaymentMethodWaffoPancake,
|
||||
}
|
||||
|
||||
func isNonEpayPaymentMethodForEpayCallback(paymentMethod string) bool {
|
||||
return lo.Contains(nonEpayPaymentMethodsForCallback, paymentMethod)
|
||||
}
|
||||
|
||||
func GetEpayClient() *epay.Client {
|
||||
if operation_setting.PayAddress == "" || operation_setting.EpayId == "" || operation_setting.EpayKey == "" {
|
||||
return nil
|
||||
@@ -248,13 +237,14 @@ func RequestEpay(c *gin.Context) {
|
||||
amount = dAmount.Div(dQuotaPerUnit).IntPart()
|
||||
}
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: amount,
|
||||
Money: payMoney,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: req.PaymentMethod,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: id,
|
||||
Amount: amount,
|
||||
Money: payMoney,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: req.PaymentMethod,
|
||||
PaymentProvider: model.PaymentProviderEpay,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
err = topUp.Insert()
|
||||
if err != nil {
|
||||
@@ -379,15 +369,15 @@ func EpayNotify(c *gin.Context) {
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 回调订单不存在 trade_no=%s callback_type=%s client_ip=%s verify_info=%q", verifyInfo.ServiceTradeNo, verifyInfo.Type, c.ClientIP(), common.GetJsonString(verifyInfo)))
|
||||
return
|
||||
}
|
||||
if isNonEpayPaymentMethodForEpayCallback(topUp.PaymentMethod) {
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 订单支付方式不匹配 trade_no=%s order_payment_method=%s callback_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentMethod, verifyInfo.Type, c.ClientIP()))
|
||||
return
|
||||
}
|
||||
if topUp.PaymentMethod != verifyInfo.Type {
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 订单支付方式不匹配 trade_no=%s order_payment_method=%s callback_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentMethod, verifyInfo.Type, c.ClientIP()))
|
||||
if topUp.PaymentProvider != model.PaymentProviderEpay {
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 订单支付网关不匹配 trade_no=%s order_provider=%s callback_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentProvider, verifyInfo.Type, c.ClientIP()))
|
||||
return
|
||||
}
|
||||
if topUp.Status == common.TopUpStatusPending {
|
||||
if topUp.PaymentMethod != verifyInfo.Type {
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 实际支付方式与订单不同 trade_no=%s order_payment_method=%s actual_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentMethod, verifyInfo.Type, c.ClientIP()))
|
||||
topUp.PaymentMethod = verifyInfo.Type
|
||||
}
|
||||
topUp.Status = common.TopUpStatusSuccess
|
||||
err := topUp.Update()
|
||||
if err != nil {
|
||||
|
||||
@@ -106,13 +106,14 @@ func (*CreemAdaptor) RequestPay(c *gin.Context, req *CreemPayRequest) {
|
||||
|
||||
// 先创建订单记录,使用产品配置的金额和充值额度
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: selectedProduct.Quota, // 充值额度
|
||||
Money: selectedProduct.Price, // 支付金额
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: model.PaymentMethodCreem,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: id,
|
||||
Amount: selectedProduct.Quota, // 充值额度
|
||||
Money: selectedProduct.Price, // 支付金额
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: model.PaymentMethodCreem,
|
||||
PaymentProvider: model.PaymentProviderCreem,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
err = topUp.Insert()
|
||||
if err != nil {
|
||||
@@ -301,7 +302,7 @@ func handleCheckoutCompleted(c *gin.Context, event *CreemWebhookEvent) {
|
||||
// Try complete subscription order first
|
||||
LockOrder(referenceId)
|
||||
defer UnlockOrder(referenceId)
|
||||
if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(event), model.PaymentMethodCreem); err == nil {
|
||||
if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(event), model.PaymentProviderCreem, ""); err == nil {
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 订阅订单处理成功 trade_no=%s creem_order_id=%s", referenceId, event.Object.Order.Id))
|
||||
c.Status(http.StatusOK)
|
||||
return
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
)
|
||||
|
||||
func TestIsNonEpayPaymentMethodForEpayCallback(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
paymentMethod string
|
||||
expectedBlocked bool
|
||||
}{
|
||||
{name: "stripe", paymentMethod: model.PaymentMethodStripe, expectedBlocked: true},
|
||||
{name: "creem", paymentMethod: model.PaymentMethodCreem, expectedBlocked: true},
|
||||
{name: "waffo", paymentMethod: model.PaymentMethodWaffo, expectedBlocked: true},
|
||||
{name: "waffo pancake", paymentMethod: model.PaymentMethodWaffoPancake, expectedBlocked: true},
|
||||
{name: "alipay", paymentMethod: "alipay", expectedBlocked: false},
|
||||
{name: "wxpay", paymentMethod: "wxpay", expectedBlocked: false},
|
||||
{name: "custom epay type", paymentMethod: "custom1", expectedBlocked: false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if actual := isNonEpayPaymentMethodForEpayCallback(tc.paymentMethod); actual != tc.expectedBlocked {
|
||||
t.Fatalf("expected blocked=%v, got %v for payment method %q", tc.expectedBlocked, actual, tc.paymentMethod)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+13
-12
@@ -101,13 +101,14 @@ func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) {
|
||||
}
|
||||
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: req.Amount,
|
||||
Money: chargedMoney,
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: model.PaymentMethodStripe,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: id,
|
||||
Amount: req.Amount,
|
||||
Money: chargedMoney,
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: model.PaymentMethodStripe,
|
||||
PaymentProvider: model.PaymentProviderStripe,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
err = topUp.Insert()
|
||||
if err != nil {
|
||||
@@ -237,8 +238,8 @@ func sessionAsyncPaymentFailed(ctx context.Context, event stripe.Event, callerIp
|
||||
return
|
||||
}
|
||||
|
||||
if topUp.PaymentMethod != model.PaymentMethodStripe {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败但订单支付方式不匹配 trade_no=%s payment_method=%s client_ip=%s", referenceId, topUp.PaymentMethod, callerIp))
|
||||
if topUp.PaymentProvider != model.PaymentProviderStripe {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败但订单支付网关不匹配 trade_no=%s payment_provider=%s client_ip=%s", referenceId, topUp.PaymentProvider, callerIp))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -270,7 +271,7 @@ func fulfillOrder(ctx context.Context, event stripe.Event, referenceId string, c
|
||||
"currency": strings.ToUpper(event.GetObjectValue("currency")),
|
||||
"event_type": string(event.Type),
|
||||
}
|
||||
if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(payload), model.PaymentMethodStripe); err == nil {
|
||||
if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(payload), model.PaymentProviderStripe, ""); err == nil {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Stripe 订阅订单处理成功 trade_no=%s event_type=%s client_ip=%s", referenceId, string(event.Type), callerIp))
|
||||
return
|
||||
} else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) {
|
||||
@@ -305,7 +306,7 @@ func sessionExpired(ctx context.Context, event stripe.Event) {
|
||||
// Subscription order expiration
|
||||
LockOrder(referenceId)
|
||||
defer UnlockOrder(referenceId)
|
||||
if err := model.ExpireSubscriptionOrder(referenceId, model.PaymentMethodStripe); err == nil {
|
||||
if err := model.ExpireSubscriptionOrder(referenceId, model.PaymentProviderStripe); err == nil {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Stripe 订阅订单已过期 trade_no=%s", referenceId))
|
||||
return
|
||||
} else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) {
|
||||
@@ -313,7 +314,7 @@ func sessionExpired(ctx context.Context, event stripe.Event) {
|
||||
return
|
||||
}
|
||||
|
||||
err := model.UpdatePendingTopUpStatus(referenceId, model.PaymentMethodStripe, common.TopUpStatusExpired)
|
||||
err := model.UpdatePendingTopUpStatus(referenceId, model.PaymentProviderStripe, common.TopUpStatusExpired)
|
||||
if errors.Is(err, model.ErrTopUpNotFound) {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("Stripe 充值订单不存在,无法标记过期 trade_no=%s", referenceId))
|
||||
return
|
||||
|
||||
@@ -208,13 +208,14 @@ func RequestWaffoPay(c *gin.Context) {
|
||||
|
||||
// 创建本地订单
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: amount,
|
||||
Money: payMoney,
|
||||
TradeNo: merchantOrderId,
|
||||
PaymentMethod: model.PaymentMethodWaffo,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: id,
|
||||
Amount: amount,
|
||||
Money: payMoney,
|
||||
TradeNo: merchantOrderId,
|
||||
PaymentMethod: model.PaymentMethodWaffo,
|
||||
PaymentProvider: model.PaymentProviderWaffo,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
if err := topUp.Insert(); err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, merchantOrderId, req.Amount, err.Error()))
|
||||
@@ -379,7 +380,7 @@ func handleWaffoPayment(c *gin.Context, wh *core.WebhookHandler, result *core.Pa
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo 订单状态非成功,忽略充值 trade_no=%s order_status=%s client_ip=%s", result.MerchantOrderID, result.OrderStatus, c.ClientIP()))
|
||||
// 终态失败订单标记为 failed,避免永远停在 pending
|
||||
if result.MerchantOrderID != "" {
|
||||
if err := model.UpdatePendingTopUpStatus(result.MerchantOrderID, model.PaymentMethodWaffo, common.TopUpStatusFailed); err != nil &&
|
||||
if err := model.UpdatePendingTopUpStatus(result.MerchantOrderID, model.PaymentProviderWaffo, common.TopUpStatusFailed); err != nil &&
|
||||
!errors.Is(err, model.ErrTopUpNotFound) &&
|
||||
!errors.Is(err, model.ErrTopUpStatusInvalid) {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 标记失败订单状态失败 trade_no=%s error=%q", result.MerchantOrderID, err.Error()))
|
||||
|
||||
@@ -159,13 +159,14 @@ func RequestWaffoPancakePay(c *gin.Context) {
|
||||
|
||||
tradeNo := fmt.Sprintf("WAFFO_PANCAKE-%d-%d-%s", id, time.Now().UnixMilli(), randstr.String(6))
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: normalizeWaffoPancakeTopUpAmount(req.Amount),
|
||||
Money: payMoney,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: model.PaymentMethodWaffoPancake,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: id,
|
||||
Amount: normalizeWaffoPancakeTopUpAmount(req.Amount),
|
||||
Money: payMoney,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: model.PaymentMethodWaffoPancake,
|
||||
PaymentProvider: model.PaymentProviderWaffoPancake,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
if err := topUp.Insert(); err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, tradeNo, req.Amount, err.Error()))
|
||||
|
||||
@@ -36,30 +36,32 @@ func insertSubscriptionPlanForPaymentGuardTest(t *testing.T, id int) *Subscripti
|
||||
return plan
|
||||
}
|
||||
|
||||
func insertSubscriptionOrderForPaymentGuardTest(t *testing.T, tradeNo string, userID int, planID int, paymentMethod string) {
|
||||
func insertSubscriptionOrderForPaymentGuardTest(t *testing.T, tradeNo string, userID int, planID int, paymentProvider string) {
|
||||
t.Helper()
|
||||
order := &SubscriptionOrder{
|
||||
UserId: userID,
|
||||
PlanId: planID,
|
||||
Money: 9.99,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: paymentMethod,
|
||||
Status: common.TopUpStatusPending,
|
||||
CreateTime: time.Now().Unix(),
|
||||
UserId: userID,
|
||||
PlanId: planID,
|
||||
Money: 9.99,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: paymentProvider,
|
||||
PaymentProvider: paymentProvider,
|
||||
Status: common.TopUpStatusPending,
|
||||
CreateTime: time.Now().Unix(),
|
||||
}
|
||||
require.NoError(t, order.Insert())
|
||||
}
|
||||
|
||||
func insertTopUpForPaymentGuardTest(t *testing.T, tradeNo string, userID int, paymentMethod string) {
|
||||
func insertTopUpForPaymentGuardTest(t *testing.T, tradeNo string, userID int, paymentProvider string) {
|
||||
t.Helper()
|
||||
topUp := &TopUp{
|
||||
UserId: userID,
|
||||
Amount: 2,
|
||||
Money: 9.99,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: paymentMethod,
|
||||
Status: common.TopUpStatusPending,
|
||||
CreateTime: time.Now().Unix(),
|
||||
UserId: userID,
|
||||
Amount: 2,
|
||||
Money: 9.99,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: paymentProvider,
|
||||
PaymentProvider: paymentProvider,
|
||||
Status: common.TopUpStatusPending,
|
||||
CreateTime: time.Now().Unix(),
|
||||
}
|
||||
require.NoError(t, topUp.Insert())
|
||||
}
|
||||
@@ -89,7 +91,7 @@ func TestRechargeWaffoPancake_RejectsMismatchedPaymentMethod(t *testing.T) {
|
||||
truncateTables(t)
|
||||
|
||||
insertUserForPaymentGuardTest(t, 101, 0)
|
||||
insertTopUpForPaymentGuardTest(t, "waffo-pancake-guard", 101, PaymentMethodStripe)
|
||||
insertTopUpForPaymentGuardTest(t, "waffo-pancake-guard", 101, PaymentProviderStripe)
|
||||
|
||||
err := RechargeWaffoPancake("waffo-pancake-guard")
|
||||
require.Error(t, err)
|
||||
@@ -100,27 +102,27 @@ func TestRechargeWaffoPancake_RejectsMismatchedPaymentMethod(t *testing.T) {
|
||||
assert.Equal(t, 0, getUserQuotaForPaymentGuardTest(t, 101))
|
||||
}
|
||||
|
||||
func TestUpdatePendingTopUpStatus_RejectsMismatchedPaymentMethod(t *testing.T) {
|
||||
func TestUpdatePendingTopUpStatus_RejectsMismatchedPaymentProvider(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
tradeNo string
|
||||
storedPaymentMethod string
|
||||
expectedPaymentMethod string
|
||||
targetStatus string
|
||||
name string
|
||||
tradeNo string
|
||||
storedPaymentProvider string
|
||||
expectedPaymentProvider string
|
||||
targetStatus string
|
||||
}{
|
||||
{
|
||||
name: "stripe expire",
|
||||
tradeNo: "stripe-expire-guard",
|
||||
storedPaymentMethod: PaymentMethodCreem,
|
||||
expectedPaymentMethod: PaymentMethodStripe,
|
||||
targetStatus: common.TopUpStatusExpired,
|
||||
name: "stripe expire",
|
||||
tradeNo: "stripe-expire-guard",
|
||||
storedPaymentProvider: PaymentProviderCreem,
|
||||
expectedPaymentProvider: PaymentProviderStripe,
|
||||
targetStatus: common.TopUpStatusExpired,
|
||||
},
|
||||
{
|
||||
name: "waffo failed",
|
||||
tradeNo: "waffo-failed-guard",
|
||||
storedPaymentMethod: PaymentMethodStripe,
|
||||
expectedPaymentMethod: PaymentMethodWaffo,
|
||||
targetStatus: common.TopUpStatusFailed,
|
||||
name: "waffo failed",
|
||||
tradeNo: "waffo-failed-guard",
|
||||
storedPaymentProvider: PaymentProviderStripe,
|
||||
expectedPaymentProvider: PaymentProviderWaffo,
|
||||
targetStatus: common.TopUpStatusFailed,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -128,23 +130,23 @@ func TestUpdatePendingTopUpStatus_RejectsMismatchedPaymentMethod(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
truncateTables(t)
|
||||
insertUserForPaymentGuardTest(t, 150, 0)
|
||||
insertTopUpForPaymentGuardTest(t, tc.tradeNo, 150, tc.storedPaymentMethod)
|
||||
insertTopUpForPaymentGuardTest(t, tc.tradeNo, 150, tc.storedPaymentProvider)
|
||||
|
||||
err := UpdatePendingTopUpStatus(tc.tradeNo, tc.expectedPaymentMethod, tc.targetStatus)
|
||||
err := UpdatePendingTopUpStatus(tc.tradeNo, tc.expectedPaymentProvider, tc.targetStatus)
|
||||
require.ErrorIs(t, err, ErrPaymentMethodMismatch)
|
||||
assert.Equal(t, common.TopUpStatusPending, getTopUpStatusForPaymentGuardTest(t, tc.tradeNo))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompleteSubscriptionOrder_RejectsMismatchedPaymentMethod(t *testing.T) {
|
||||
func TestCompleteSubscriptionOrder_RejectsMismatchedPaymentProvider(t *testing.T) {
|
||||
truncateTables(t)
|
||||
|
||||
insertUserForPaymentGuardTest(t, 202, 0)
|
||||
plan := insertSubscriptionPlanForPaymentGuardTest(t, 301)
|
||||
insertSubscriptionOrderForPaymentGuardTest(t, "sub-guard-order", 202, plan.Id, PaymentMethodStripe)
|
||||
insertSubscriptionOrderForPaymentGuardTest(t, "sub-guard-order", 202, plan.Id, PaymentProviderStripe)
|
||||
|
||||
err := CompleteSubscriptionOrder("sub-guard-order", `{"provider":"epay"}`, "alipay")
|
||||
err := CompleteSubscriptionOrder("sub-guard-order", `{"provider":"epay"}`, PaymentProviderEpay, "alipay")
|
||||
require.ErrorIs(t, err, ErrPaymentMethodMismatch)
|
||||
|
||||
order := GetSubscriptionOrderByTradeNo("sub-guard-order")
|
||||
@@ -156,14 +158,14 @@ func TestCompleteSubscriptionOrder_RejectsMismatchedPaymentMethod(t *testing.T)
|
||||
assert.Nil(t, topUp)
|
||||
}
|
||||
|
||||
func TestExpireSubscriptionOrder_RejectsMismatchedPaymentMethod(t *testing.T) {
|
||||
func TestExpireSubscriptionOrder_RejectsMismatchedPaymentProvider(t *testing.T) {
|
||||
truncateTables(t)
|
||||
|
||||
insertUserForPaymentGuardTest(t, 303, 0)
|
||||
plan := insertSubscriptionPlanForPaymentGuardTest(t, 401)
|
||||
insertSubscriptionOrderForPaymentGuardTest(t, "sub-expire-guard", 303, plan.Id, PaymentMethodStripe)
|
||||
insertSubscriptionOrderForPaymentGuardTest(t, "sub-expire-guard", 303, plan.Id, PaymentProviderStripe)
|
||||
|
||||
err := ExpireSubscriptionOrder("sub-expire-guard", PaymentMethodCreem)
|
||||
err := ExpireSubscriptionOrder("sub-expire-guard", PaymentProviderCreem)
|
||||
require.ErrorIs(t, err, ErrPaymentMethodMismatch)
|
||||
|
||||
order := GetSubscriptionOrderByTradeNo("sub-expire-guard")
|
||||
|
||||
+15
-9
@@ -198,11 +198,12 @@ type SubscriptionOrder struct {
|
||||
PlanId int `json:"plan_id" gorm:"index"`
|
||||
Money float64 `json:"money"`
|
||||
|
||||
TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
|
||||
PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"`
|
||||
Status string `json:"status"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
CompleteTime int64 `json:"complete_time"`
|
||||
TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
|
||||
PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"`
|
||||
PaymentProvider string `json:"payment_provider" gorm:"type:varchar(50);default:''"`
|
||||
Status string `json:"status"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
CompleteTime int64 `json:"complete_time"`
|
||||
|
||||
ProviderPayload string `json:"provider_payload" gorm:"type:text"`
|
||||
}
|
||||
@@ -505,7 +506,9 @@ func CreateUserSubscriptionFromPlanTx(tx *gorm.DB, userId int, plan *Subscriptio
|
||||
}
|
||||
|
||||
// Complete a subscription order (idempotent). Creates a UserSubscription snapshot from the plan.
|
||||
func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedPaymentMethod string) error {
|
||||
// expectedPaymentProvider guards against cross-gateway callback attacks (empty skips the check).
|
||||
// actualPaymentMethod updates the order's PaymentMethod to reflect the real payment type used (empty skips update).
|
||||
func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedPaymentProvider string, actualPaymentMethod string) error {
|
||||
if tradeNo == "" {
|
||||
return errors.New("tradeNo is empty")
|
||||
}
|
||||
@@ -523,7 +526,7 @@ func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedP
|
||||
if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil {
|
||||
return ErrSubscriptionOrderNotFound
|
||||
}
|
||||
if expectedPaymentMethod != "" && order.PaymentMethod != expectedPaymentMethod {
|
||||
if expectedPaymentProvider != "" && order.PaymentProvider != expectedPaymentProvider {
|
||||
return ErrPaymentMethodMismatch
|
||||
}
|
||||
if order.Status == common.TopUpStatusSuccess {
|
||||
@@ -552,6 +555,9 @@ func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedP
|
||||
if providerPayload != "" {
|
||||
order.ProviderPayload = providerPayload
|
||||
}
|
||||
if actualPaymentMethod != "" && order.PaymentMethod != actualPaymentMethod {
|
||||
order.PaymentMethod = actualPaymentMethod
|
||||
}
|
||||
if err := tx.Save(&order).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -610,7 +616,7 @@ func upsertSubscriptionTopUpTx(tx *gorm.DB, order *SubscriptionOrder) error {
|
||||
return tx.Save(&topup).Error
|
||||
}
|
||||
|
||||
func ExpireSubscriptionOrder(tradeNo string, expectedPaymentMethod string) error {
|
||||
func ExpireSubscriptionOrder(tradeNo string, expectedPaymentProvider string) error {
|
||||
if tradeNo == "" {
|
||||
return errors.New("tradeNo is empty")
|
||||
}
|
||||
@@ -623,7 +629,7 @@ func ExpireSubscriptionOrder(tradeNo string, expectedPaymentMethod string) error
|
||||
if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil {
|
||||
return ErrSubscriptionOrderNotFound
|
||||
}
|
||||
if expectedPaymentMethod != "" && order.PaymentMethod != expectedPaymentMethod {
|
||||
if expectedPaymentProvider != "" && order.PaymentProvider != expectedPaymentProvider {
|
||||
return ErrPaymentMethodMismatch
|
||||
}
|
||||
if order.Status != common.TopUpStatusPending {
|
||||
|
||||
+25
-16
@@ -12,15 +12,16 @@ import (
|
||||
)
|
||||
|
||||
type TopUp struct {
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Amount int64 `json:"amount"`
|
||||
Money float64 `json:"money"`
|
||||
TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
|
||||
PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
CompleteTime int64 `json:"complete_time"`
|
||||
Status string `json:"status"`
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Amount int64 `json:"amount"`
|
||||
Money float64 `json:"money"`
|
||||
TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
|
||||
PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"`
|
||||
PaymentProvider string `json:"payment_provider" gorm:"type:varchar(50);default:''"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
CompleteTime int64 `json:"complete_time"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -30,6 +31,14 @@ const (
|
||||
PaymentMethodWaffoPancake = "waffo_pancake"
|
||||
)
|
||||
|
||||
const (
|
||||
PaymentProviderEpay = "epay"
|
||||
PaymentProviderStripe = "stripe"
|
||||
PaymentProviderCreem = "creem"
|
||||
PaymentProviderWaffo = "waffo"
|
||||
PaymentProviderWaffoPancake = "waffo_pancake"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPaymentMethodMismatch = errors.New("payment method mismatch")
|
||||
ErrTopUpNotFound = errors.New("topup not found")
|
||||
@@ -68,7 +77,7 @@ func GetTopUpByTradeNo(tradeNo string) *TopUp {
|
||||
return topUp
|
||||
}
|
||||
|
||||
func UpdatePendingTopUpStatus(tradeNo string, expectedPaymentMethod string, targetStatus string) error {
|
||||
func UpdatePendingTopUpStatus(tradeNo string, expectedPaymentProvider string, targetStatus string) error {
|
||||
if tradeNo == "" {
|
||||
return errors.New("未提供支付单号")
|
||||
}
|
||||
@@ -83,7 +92,7 @@ func UpdatePendingTopUpStatus(tradeNo string, expectedPaymentMethod string, targ
|
||||
if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(topUp).Error; err != nil {
|
||||
return ErrTopUpNotFound
|
||||
}
|
||||
if expectedPaymentMethod != "" && topUp.PaymentMethod != expectedPaymentMethod {
|
||||
if expectedPaymentProvider != "" && topUp.PaymentProvider != expectedPaymentProvider {
|
||||
return ErrPaymentMethodMismatch
|
||||
}
|
||||
if topUp.Status != common.TopUpStatusPending {
|
||||
@@ -114,7 +123,7 @@ func Recharge(referenceId string, customerId string, callerIp string) (err error
|
||||
return errors.New("充值订单不存在")
|
||||
}
|
||||
|
||||
if topUp.PaymentMethod != PaymentMethodStripe {
|
||||
if topUp.PaymentProvider != PaymentProviderStripe {
|
||||
return ErrPaymentMethodMismatch
|
||||
}
|
||||
|
||||
@@ -340,7 +349,7 @@ func ManualCompleteTopUp(tradeNo string, callerIp string) error {
|
||||
// 计算应充值额度:
|
||||
// - Stripe 订单:Money 代表经分组倍率换算后的美元数量,直接 * QuotaPerUnit
|
||||
// - 其他订单(如易支付):Amount 为美元数量,* QuotaPerUnit
|
||||
if topUp.PaymentMethod == PaymentMethodStripe {
|
||||
if topUp.PaymentProvider == PaymentProviderStripe {
|
||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||
quotaToAdd = int(decimal.NewFromFloat(topUp.Money).Mul(dQuotaPerUnit).IntPart())
|
||||
} else {
|
||||
@@ -397,7 +406,7 @@ func RechargeCreem(referenceId string, customerEmail string, customerName string
|
||||
return errors.New("充值订单不存在")
|
||||
}
|
||||
|
||||
if topUp.PaymentMethod != PaymentMethodCreem {
|
||||
if topUp.PaymentProvider != PaymentProviderCreem {
|
||||
return ErrPaymentMethodMismatch
|
||||
}
|
||||
|
||||
@@ -472,7 +481,7 @@ func RechargeWaffo(tradeNo string, callerIp string) (err error) {
|
||||
return errors.New("充值订单不存在")
|
||||
}
|
||||
|
||||
if topUp.PaymentMethod != PaymentMethodWaffo {
|
||||
if topUp.PaymentProvider != PaymentProviderWaffo {
|
||||
return ErrPaymentMethodMismatch
|
||||
}
|
||||
|
||||
@@ -535,7 +544,7 @@ func RechargeWaffoPancake(tradeNo string) (err error) {
|
||||
return errors.New("充值订单不存在")
|
||||
}
|
||||
|
||||
if topUp.PaymentMethod != PaymentMethodWaffoPancake {
|
||||
if topUp.PaymentProvider != PaymentProviderWaffoPancake {
|
||||
return ErrPaymentMethodMismatch
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user