From a7c38ec851a3fbd123dac5b3185f2fed6c179944 Mon Sep 17 00:00:00 2001 From: CaIon Date: Fri, 24 Apr 2026 22:16:16 +0800 Subject: [PATCH] fix: add PaymentProvider field to prevent cross-gateway callback attacks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit EPay allows users to switch payment methods (e.g. wxpay→alipay) during checkout, causing callback rejection. Replace fragile blocklist guard with a PaymentProvider field on TopUp and SubscriptionOrder that identifies which gateway created the order. --- controller/subscription_payment_creem.go | 15 ++-- controller/subscription_payment_epay.go | 21 +++--- controller/subscription_payment_stripe.go | 15 ++-- controller/topup.go | 38 ++++------ controller/topup_creem.go | 17 ++--- controller/topup_epay_guard_test.go | 31 --------- controller/topup_stripe.go | 25 +++---- controller/topup_waffo.go | 17 ++--- controller/topup_waffo_pancake.go | 15 ++-- model/payment_method_guard_test.go | 84 ++++++++++++----------- model/subscription.go | 24 ++++--- model/topup.go | 41 ++++++----- 12 files changed, 163 insertions(+), 180 deletions(-) delete mode 100644 controller/topup_epay_guard_test.go diff --git a/controller/subscription_payment_creem.go b/controller/subscription_payment_creem.go index 935429ac..18e1a584 100644 --- a/controller/subscription_payment_creem.go +++ b/controller/subscription_payment_creem.go @@ -83,13 +83,14 @@ func SubscriptionRequestCreemPay(c *gin.Context) { // create pending order first order := &model.SubscriptionOrder{ - UserId: userId, - PlanId: plan.Id, - Money: plan.PriceAmount, - TradeNo: referenceId, - PaymentMethod: model.PaymentMethodCreem, - CreateTime: time.Now().Unix(), - Status: common.TopUpStatusPending, + UserId: userId, + PlanId: plan.Id, + Money: plan.PriceAmount, + TradeNo: referenceId, + PaymentMethod: model.PaymentMethodCreem, + PaymentProvider: model.PaymentProviderCreem, + CreateTime: time.Now().Unix(), + Status: common.TopUpStatusPending, } if err := order.Insert(); err != nil { c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"}) diff --git a/controller/subscription_payment_epay.go b/controller/subscription_payment_epay.go index 8f7848d5..2567654f 100644 --- a/controller/subscription_payment_epay.go +++ b/controller/subscription_payment_epay.go @@ -82,13 +82,14 @@ func SubscriptionRequestEpay(c *gin.Context) { } order := &model.SubscriptionOrder{ - UserId: userId, - PlanId: plan.Id, - Money: plan.PriceAmount, - TradeNo: tradeNo, - PaymentMethod: req.PaymentMethod, - CreateTime: time.Now().Unix(), - Status: common.TopUpStatusPending, + UserId: userId, + PlanId: plan.Id, + Money: plan.PriceAmount, + TradeNo: tradeNo, + PaymentMethod: req.PaymentMethod, + PaymentProvider: model.PaymentProviderEpay, + CreateTime: time.Now().Unix(), + Status: common.TopUpStatusPending, } if err := order.Insert(); err != nil { common.ApiErrorMsg(c, "创建订单失败") @@ -104,7 +105,7 @@ func SubscriptionRequestEpay(c *gin.Context) { ReturnUrl: returnUrl, }) if err != nil { - _ = model.ExpireSubscriptionOrder(tradeNo, req.PaymentMethod) + _ = model.ExpireSubscriptionOrder(tradeNo, model.PaymentProviderEpay) common.ApiErrorMsg(c, "拉起支付失败") return } @@ -156,7 +157,7 @@ func SubscriptionEpayNotify(c *gin.Context) { LockOrder(verifyInfo.ServiceTradeNo) defer UnlockOrder(verifyInfo.ServiceTradeNo) - if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), verifyInfo.Type); err != nil { + if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), model.PaymentProviderEpay, verifyInfo.Type); err != nil { _, _ = c.Writer.Write([]byte("fail")) return } @@ -205,7 +206,7 @@ func SubscriptionEpayReturn(c *gin.Context) { if verifyInfo.TradeStatus == epay.StatusTradeSuccess { LockOrder(verifyInfo.ServiceTradeNo) defer UnlockOrder(verifyInfo.ServiceTradeNo) - if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), verifyInfo.Type); err != nil { + if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), model.PaymentProviderEpay, verifyInfo.Type); err != nil { c.Redirect(http.StatusFound, system_setting.ServerAddress+"/console/topup?pay=fail") return } diff --git a/controller/subscription_payment_stripe.go b/controller/subscription_payment_stripe.go index 9824c90d..a5ce4685 100644 --- a/controller/subscription_payment_stripe.go +++ b/controller/subscription_payment_stripe.go @@ -84,13 +84,14 @@ func SubscriptionRequestStripePay(c *gin.Context) { } order := &model.SubscriptionOrder{ - UserId: userId, - PlanId: plan.Id, - Money: plan.PriceAmount, - TradeNo: referenceId, - PaymentMethod: model.PaymentMethodStripe, - CreateTime: time.Now().Unix(), - Status: common.TopUpStatusPending, + UserId: userId, + PlanId: plan.Id, + Money: plan.PriceAmount, + TradeNo: referenceId, + PaymentMethod: model.PaymentMethodStripe, + PaymentProvider: model.PaymentProviderStripe, + CreateTime: time.Now().Unix(), + Status: common.TopUpStatusPending, } if err := order.Insert(); err != nil { c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"}) diff --git a/controller/topup.go b/controller/topup.go index 86d361a3..a6445b40 100644 --- a/controller/topup.go +++ b/controller/topup.go @@ -123,17 +123,6 @@ type AmountRequest struct { Amount int64 `json:"amount"` } -var nonEpayPaymentMethodsForCallback = []string{ - model.PaymentMethodStripe, - model.PaymentMethodCreem, - model.PaymentMethodWaffo, - model.PaymentMethodWaffoPancake, -} - -func isNonEpayPaymentMethodForEpayCallback(paymentMethod string) bool { - return lo.Contains(nonEpayPaymentMethodsForCallback, paymentMethod) -} - func GetEpayClient() *epay.Client { if operation_setting.PayAddress == "" || operation_setting.EpayId == "" || operation_setting.EpayKey == "" { return nil @@ -248,13 +237,14 @@ func RequestEpay(c *gin.Context) { amount = dAmount.Div(dQuotaPerUnit).IntPart() } topUp := &model.TopUp{ - UserId: id, - Amount: amount, - Money: payMoney, - TradeNo: tradeNo, - PaymentMethod: req.PaymentMethod, - CreateTime: time.Now().Unix(), - Status: common.TopUpStatusPending, + UserId: id, + Amount: amount, + Money: payMoney, + TradeNo: tradeNo, + PaymentMethod: req.PaymentMethod, + PaymentProvider: model.PaymentProviderEpay, + CreateTime: time.Now().Unix(), + Status: common.TopUpStatusPending, } err = topUp.Insert() if err != nil { @@ -379,15 +369,15 @@ func EpayNotify(c *gin.Context) { logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 回调订单不存在 trade_no=%s callback_type=%s client_ip=%s verify_info=%q", verifyInfo.ServiceTradeNo, verifyInfo.Type, c.ClientIP(), common.GetJsonString(verifyInfo))) return } - if isNonEpayPaymentMethodForEpayCallback(topUp.PaymentMethod) { - logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 订单支付方式不匹配 trade_no=%s order_payment_method=%s callback_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentMethod, verifyInfo.Type, c.ClientIP())) - return - } - if topUp.PaymentMethod != verifyInfo.Type { - logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 订单支付方式不匹配 trade_no=%s order_payment_method=%s callback_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentMethod, verifyInfo.Type, c.ClientIP())) + if topUp.PaymentProvider != model.PaymentProviderEpay { + logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 订单支付网关不匹配 trade_no=%s order_provider=%s callback_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentProvider, verifyInfo.Type, c.ClientIP())) return } if topUp.Status == common.TopUpStatusPending { + if topUp.PaymentMethod != verifyInfo.Type { + logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 实际支付方式与订单不同 trade_no=%s order_payment_method=%s actual_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentMethod, verifyInfo.Type, c.ClientIP())) + topUp.PaymentMethod = verifyInfo.Type + } topUp.Status = common.TopUpStatusSuccess err := topUp.Update() if err != nil { diff --git a/controller/topup_creem.go b/controller/topup_creem.go index 139dd43f..7472690e 100644 --- a/controller/topup_creem.go +++ b/controller/topup_creem.go @@ -106,13 +106,14 @@ func (*CreemAdaptor) RequestPay(c *gin.Context, req *CreemPayRequest) { // 先创建订单记录,使用产品配置的金额和充值额度 topUp := &model.TopUp{ - UserId: id, - Amount: selectedProduct.Quota, // 充值额度 - Money: selectedProduct.Price, // 支付金额 - TradeNo: referenceId, - PaymentMethod: model.PaymentMethodCreem, - CreateTime: time.Now().Unix(), - Status: common.TopUpStatusPending, + UserId: id, + Amount: selectedProduct.Quota, // 充值额度 + Money: selectedProduct.Price, // 支付金额 + TradeNo: referenceId, + PaymentMethod: model.PaymentMethodCreem, + PaymentProvider: model.PaymentProviderCreem, + CreateTime: time.Now().Unix(), + Status: common.TopUpStatusPending, } err = topUp.Insert() if err != nil { @@ -301,7 +302,7 @@ func handleCheckoutCompleted(c *gin.Context, event *CreemWebhookEvent) { // Try complete subscription order first LockOrder(referenceId) defer UnlockOrder(referenceId) - if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(event), model.PaymentMethodCreem); err == nil { + if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(event), model.PaymentProviderCreem, ""); err == nil { logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 订阅订单处理成功 trade_no=%s creem_order_id=%s", referenceId, event.Object.Order.Id)) c.Status(http.StatusOK) return diff --git a/controller/topup_epay_guard_test.go b/controller/topup_epay_guard_test.go deleted file mode 100644 index 34512665..00000000 --- a/controller/topup_epay_guard_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package controller - -import ( - "testing" - - "github.com/QuantumNous/new-api/model" -) - -func TestIsNonEpayPaymentMethodForEpayCallback(t *testing.T) { - testCases := []struct { - name string - paymentMethod string - expectedBlocked bool - }{ - {name: "stripe", paymentMethod: model.PaymentMethodStripe, expectedBlocked: true}, - {name: "creem", paymentMethod: model.PaymentMethodCreem, expectedBlocked: true}, - {name: "waffo", paymentMethod: model.PaymentMethodWaffo, expectedBlocked: true}, - {name: "waffo pancake", paymentMethod: model.PaymentMethodWaffoPancake, expectedBlocked: true}, - {name: "alipay", paymentMethod: "alipay", expectedBlocked: false}, - {name: "wxpay", paymentMethod: "wxpay", expectedBlocked: false}, - {name: "custom epay type", paymentMethod: "custom1", expectedBlocked: false}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - if actual := isNonEpayPaymentMethodForEpayCallback(tc.paymentMethod); actual != tc.expectedBlocked { - t.Fatalf("expected blocked=%v, got %v for payment method %q", tc.expectedBlocked, actual, tc.paymentMethod) - } - }) - } -} diff --git a/controller/topup_stripe.go b/controller/topup_stripe.go index 23ddb3b9..ceee8ecd 100644 --- a/controller/topup_stripe.go +++ b/controller/topup_stripe.go @@ -101,13 +101,14 @@ func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) { } topUp := &model.TopUp{ - UserId: id, - Amount: req.Amount, - Money: chargedMoney, - TradeNo: referenceId, - PaymentMethod: model.PaymentMethodStripe, - CreateTime: time.Now().Unix(), - Status: common.TopUpStatusPending, + UserId: id, + Amount: req.Amount, + Money: chargedMoney, + TradeNo: referenceId, + PaymentMethod: model.PaymentMethodStripe, + PaymentProvider: model.PaymentProviderStripe, + CreateTime: time.Now().Unix(), + Status: common.TopUpStatusPending, } err = topUp.Insert() if err != nil { @@ -237,8 +238,8 @@ func sessionAsyncPaymentFailed(ctx context.Context, event stripe.Event, callerIp return } - if topUp.PaymentMethod != model.PaymentMethodStripe { - logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败但订单支付方式不匹配 trade_no=%s payment_method=%s client_ip=%s", referenceId, topUp.PaymentMethod, callerIp)) + if topUp.PaymentProvider != model.PaymentProviderStripe { + logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败但订单支付网关不匹配 trade_no=%s payment_provider=%s client_ip=%s", referenceId, topUp.PaymentProvider, callerIp)) return } @@ -270,7 +271,7 @@ func fulfillOrder(ctx context.Context, event stripe.Event, referenceId string, c "currency": strings.ToUpper(event.GetObjectValue("currency")), "event_type": string(event.Type), } - if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(payload), model.PaymentMethodStripe); err == nil { + if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(payload), model.PaymentProviderStripe, ""); err == nil { logger.LogInfo(ctx, fmt.Sprintf("Stripe 订阅订单处理成功 trade_no=%s event_type=%s client_ip=%s", referenceId, string(event.Type), callerIp)) return } else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) { @@ -305,7 +306,7 @@ func sessionExpired(ctx context.Context, event stripe.Event) { // Subscription order expiration LockOrder(referenceId) defer UnlockOrder(referenceId) - if err := model.ExpireSubscriptionOrder(referenceId, model.PaymentMethodStripe); err == nil { + if err := model.ExpireSubscriptionOrder(referenceId, model.PaymentProviderStripe); err == nil { logger.LogInfo(ctx, fmt.Sprintf("Stripe 订阅订单已过期 trade_no=%s", referenceId)) return } else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) { @@ -313,7 +314,7 @@ func sessionExpired(ctx context.Context, event stripe.Event) { return } - err := model.UpdatePendingTopUpStatus(referenceId, model.PaymentMethodStripe, common.TopUpStatusExpired) + err := model.UpdatePendingTopUpStatus(referenceId, model.PaymentProviderStripe, common.TopUpStatusExpired) if errors.Is(err, model.ErrTopUpNotFound) { logger.LogWarn(ctx, fmt.Sprintf("Stripe 充值订单不存在,无法标记过期 trade_no=%s", referenceId)) return diff --git a/controller/topup_waffo.go b/controller/topup_waffo.go index c0068062..1885c1de 100644 --- a/controller/topup_waffo.go +++ b/controller/topup_waffo.go @@ -208,13 +208,14 @@ func RequestWaffoPay(c *gin.Context) { // 创建本地订单 topUp := &model.TopUp{ - UserId: id, - Amount: amount, - Money: payMoney, - TradeNo: merchantOrderId, - PaymentMethod: model.PaymentMethodWaffo, - CreateTime: time.Now().Unix(), - Status: common.TopUpStatusPending, + UserId: id, + Amount: amount, + Money: payMoney, + TradeNo: merchantOrderId, + PaymentMethod: model.PaymentMethodWaffo, + PaymentProvider: model.PaymentProviderWaffo, + CreateTime: time.Now().Unix(), + Status: common.TopUpStatusPending, } if err := topUp.Insert(); err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, merchantOrderId, req.Amount, err.Error())) @@ -379,7 +380,7 @@ func handleWaffoPayment(c *gin.Context, wh *core.WebhookHandler, result *core.Pa logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo 订单状态非成功,忽略充值 trade_no=%s order_status=%s client_ip=%s", result.MerchantOrderID, result.OrderStatus, c.ClientIP())) // 终态失败订单标记为 failed,避免永远停在 pending if result.MerchantOrderID != "" { - if err := model.UpdatePendingTopUpStatus(result.MerchantOrderID, model.PaymentMethodWaffo, common.TopUpStatusFailed); err != nil && + if err := model.UpdatePendingTopUpStatus(result.MerchantOrderID, model.PaymentProviderWaffo, common.TopUpStatusFailed); err != nil && !errors.Is(err, model.ErrTopUpNotFound) && !errors.Is(err, model.ErrTopUpStatusInvalid) { logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 标记失败订单状态失败 trade_no=%s error=%q", result.MerchantOrderID, err.Error())) diff --git a/controller/topup_waffo_pancake.go b/controller/topup_waffo_pancake.go index 81515a56..09f15163 100644 --- a/controller/topup_waffo_pancake.go +++ b/controller/topup_waffo_pancake.go @@ -159,13 +159,14 @@ func RequestWaffoPancakePay(c *gin.Context) { tradeNo := fmt.Sprintf("WAFFO_PANCAKE-%d-%d-%s", id, time.Now().UnixMilli(), randstr.String(6)) topUp := &model.TopUp{ - UserId: id, - Amount: normalizeWaffoPancakeTopUpAmount(req.Amount), - Money: payMoney, - TradeNo: tradeNo, - PaymentMethod: model.PaymentMethodWaffoPancake, - CreateTime: time.Now().Unix(), - Status: common.TopUpStatusPending, + UserId: id, + Amount: normalizeWaffoPancakeTopUpAmount(req.Amount), + Money: payMoney, + TradeNo: tradeNo, + PaymentMethod: model.PaymentMethodWaffoPancake, + PaymentProvider: model.PaymentProviderWaffoPancake, + CreateTime: time.Now().Unix(), + Status: common.TopUpStatusPending, } if err := topUp.Insert(); err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, tradeNo, req.Amount, err.Error())) diff --git a/model/payment_method_guard_test.go b/model/payment_method_guard_test.go index 9bc29244..7f4f15cc 100644 --- a/model/payment_method_guard_test.go +++ b/model/payment_method_guard_test.go @@ -36,30 +36,32 @@ func insertSubscriptionPlanForPaymentGuardTest(t *testing.T, id int) *Subscripti return plan } -func insertSubscriptionOrderForPaymentGuardTest(t *testing.T, tradeNo string, userID int, planID int, paymentMethod string) { +func insertSubscriptionOrderForPaymentGuardTest(t *testing.T, tradeNo string, userID int, planID int, paymentProvider string) { t.Helper() order := &SubscriptionOrder{ - UserId: userID, - PlanId: planID, - Money: 9.99, - TradeNo: tradeNo, - PaymentMethod: paymentMethod, - Status: common.TopUpStatusPending, - CreateTime: time.Now().Unix(), + UserId: userID, + PlanId: planID, + Money: 9.99, + TradeNo: tradeNo, + PaymentMethod: paymentProvider, + PaymentProvider: paymentProvider, + Status: common.TopUpStatusPending, + CreateTime: time.Now().Unix(), } require.NoError(t, order.Insert()) } -func insertTopUpForPaymentGuardTest(t *testing.T, tradeNo string, userID int, paymentMethod string) { +func insertTopUpForPaymentGuardTest(t *testing.T, tradeNo string, userID int, paymentProvider string) { t.Helper() topUp := &TopUp{ - UserId: userID, - Amount: 2, - Money: 9.99, - TradeNo: tradeNo, - PaymentMethod: paymentMethod, - Status: common.TopUpStatusPending, - CreateTime: time.Now().Unix(), + UserId: userID, + Amount: 2, + Money: 9.99, + TradeNo: tradeNo, + PaymentMethod: paymentProvider, + PaymentProvider: paymentProvider, + Status: common.TopUpStatusPending, + CreateTime: time.Now().Unix(), } require.NoError(t, topUp.Insert()) } @@ -89,7 +91,7 @@ func TestRechargeWaffoPancake_RejectsMismatchedPaymentMethod(t *testing.T) { truncateTables(t) insertUserForPaymentGuardTest(t, 101, 0) - insertTopUpForPaymentGuardTest(t, "waffo-pancake-guard", 101, PaymentMethodStripe) + insertTopUpForPaymentGuardTest(t, "waffo-pancake-guard", 101, PaymentProviderStripe) err := RechargeWaffoPancake("waffo-pancake-guard") require.Error(t, err) @@ -100,27 +102,27 @@ func TestRechargeWaffoPancake_RejectsMismatchedPaymentMethod(t *testing.T) { assert.Equal(t, 0, getUserQuotaForPaymentGuardTest(t, 101)) } -func TestUpdatePendingTopUpStatus_RejectsMismatchedPaymentMethod(t *testing.T) { +func TestUpdatePendingTopUpStatus_RejectsMismatchedPaymentProvider(t *testing.T) { testCases := []struct { - name string - tradeNo string - storedPaymentMethod string - expectedPaymentMethod string - targetStatus string + name string + tradeNo string + storedPaymentProvider string + expectedPaymentProvider string + targetStatus string }{ { - name: "stripe expire", - tradeNo: "stripe-expire-guard", - storedPaymentMethod: PaymentMethodCreem, - expectedPaymentMethod: PaymentMethodStripe, - targetStatus: common.TopUpStatusExpired, + name: "stripe expire", + tradeNo: "stripe-expire-guard", + storedPaymentProvider: PaymentProviderCreem, + expectedPaymentProvider: PaymentProviderStripe, + targetStatus: common.TopUpStatusExpired, }, { - name: "waffo failed", - tradeNo: "waffo-failed-guard", - storedPaymentMethod: PaymentMethodStripe, - expectedPaymentMethod: PaymentMethodWaffo, - targetStatus: common.TopUpStatusFailed, + name: "waffo failed", + tradeNo: "waffo-failed-guard", + storedPaymentProvider: PaymentProviderStripe, + expectedPaymentProvider: PaymentProviderWaffo, + targetStatus: common.TopUpStatusFailed, }, } @@ -128,23 +130,23 @@ func TestUpdatePendingTopUpStatus_RejectsMismatchedPaymentMethod(t *testing.T) { t.Run(tc.name, func(t *testing.T) { truncateTables(t) insertUserForPaymentGuardTest(t, 150, 0) - insertTopUpForPaymentGuardTest(t, tc.tradeNo, 150, tc.storedPaymentMethod) + insertTopUpForPaymentGuardTest(t, tc.tradeNo, 150, tc.storedPaymentProvider) - err := UpdatePendingTopUpStatus(tc.tradeNo, tc.expectedPaymentMethod, tc.targetStatus) + err := UpdatePendingTopUpStatus(tc.tradeNo, tc.expectedPaymentProvider, tc.targetStatus) require.ErrorIs(t, err, ErrPaymentMethodMismatch) assert.Equal(t, common.TopUpStatusPending, getTopUpStatusForPaymentGuardTest(t, tc.tradeNo)) }) } } -func TestCompleteSubscriptionOrder_RejectsMismatchedPaymentMethod(t *testing.T) { +func TestCompleteSubscriptionOrder_RejectsMismatchedPaymentProvider(t *testing.T) { truncateTables(t) insertUserForPaymentGuardTest(t, 202, 0) plan := insertSubscriptionPlanForPaymentGuardTest(t, 301) - insertSubscriptionOrderForPaymentGuardTest(t, "sub-guard-order", 202, plan.Id, PaymentMethodStripe) + insertSubscriptionOrderForPaymentGuardTest(t, "sub-guard-order", 202, plan.Id, PaymentProviderStripe) - err := CompleteSubscriptionOrder("sub-guard-order", `{"provider":"epay"}`, "alipay") + err := CompleteSubscriptionOrder("sub-guard-order", `{"provider":"epay"}`, PaymentProviderEpay, "alipay") require.ErrorIs(t, err, ErrPaymentMethodMismatch) order := GetSubscriptionOrderByTradeNo("sub-guard-order") @@ -156,14 +158,14 @@ func TestCompleteSubscriptionOrder_RejectsMismatchedPaymentMethod(t *testing.T) assert.Nil(t, topUp) } -func TestExpireSubscriptionOrder_RejectsMismatchedPaymentMethod(t *testing.T) { +func TestExpireSubscriptionOrder_RejectsMismatchedPaymentProvider(t *testing.T) { truncateTables(t) insertUserForPaymentGuardTest(t, 303, 0) plan := insertSubscriptionPlanForPaymentGuardTest(t, 401) - insertSubscriptionOrderForPaymentGuardTest(t, "sub-expire-guard", 303, plan.Id, PaymentMethodStripe) + insertSubscriptionOrderForPaymentGuardTest(t, "sub-expire-guard", 303, plan.Id, PaymentProviderStripe) - err := ExpireSubscriptionOrder("sub-expire-guard", PaymentMethodCreem) + err := ExpireSubscriptionOrder("sub-expire-guard", PaymentProviderCreem) require.ErrorIs(t, err, ErrPaymentMethodMismatch) order := GetSubscriptionOrderByTradeNo("sub-expire-guard") diff --git a/model/subscription.go b/model/subscription.go index 10e750c3..da8fdae9 100644 --- a/model/subscription.go +++ b/model/subscription.go @@ -198,11 +198,12 @@ type SubscriptionOrder struct { PlanId int `json:"plan_id" gorm:"index"` Money float64 `json:"money"` - TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"` - PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"` - Status string `json:"status"` - CreateTime int64 `json:"create_time"` - CompleteTime int64 `json:"complete_time"` + TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"` + PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"` + PaymentProvider string `json:"payment_provider" gorm:"type:varchar(50);default:''"` + Status string `json:"status"` + CreateTime int64 `json:"create_time"` + CompleteTime int64 `json:"complete_time"` ProviderPayload string `json:"provider_payload" gorm:"type:text"` } @@ -505,7 +506,9 @@ func CreateUserSubscriptionFromPlanTx(tx *gorm.DB, userId int, plan *Subscriptio } // Complete a subscription order (idempotent). Creates a UserSubscription snapshot from the plan. -func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedPaymentMethod string) error { +// expectedPaymentProvider guards against cross-gateway callback attacks (empty skips the check). +// actualPaymentMethod updates the order's PaymentMethod to reflect the real payment type used (empty skips update). +func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedPaymentProvider string, actualPaymentMethod string) error { if tradeNo == "" { return errors.New("tradeNo is empty") } @@ -523,7 +526,7 @@ func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedP if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil { return ErrSubscriptionOrderNotFound } - if expectedPaymentMethod != "" && order.PaymentMethod != expectedPaymentMethod { + if expectedPaymentProvider != "" && order.PaymentProvider != expectedPaymentProvider { return ErrPaymentMethodMismatch } if order.Status == common.TopUpStatusSuccess { @@ -552,6 +555,9 @@ func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedP if providerPayload != "" { order.ProviderPayload = providerPayload } + if actualPaymentMethod != "" && order.PaymentMethod != actualPaymentMethod { + order.PaymentMethod = actualPaymentMethod + } if err := tx.Save(&order).Error; err != nil { return err } @@ -610,7 +616,7 @@ func upsertSubscriptionTopUpTx(tx *gorm.DB, order *SubscriptionOrder) error { return tx.Save(&topup).Error } -func ExpireSubscriptionOrder(tradeNo string, expectedPaymentMethod string) error { +func ExpireSubscriptionOrder(tradeNo string, expectedPaymentProvider string) error { if tradeNo == "" { return errors.New("tradeNo is empty") } @@ -623,7 +629,7 @@ func ExpireSubscriptionOrder(tradeNo string, expectedPaymentMethod string) error if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil { return ErrSubscriptionOrderNotFound } - if expectedPaymentMethod != "" && order.PaymentMethod != expectedPaymentMethod { + if expectedPaymentProvider != "" && order.PaymentProvider != expectedPaymentProvider { return ErrPaymentMethodMismatch } if order.Status != common.TopUpStatusPending { diff --git a/model/topup.go b/model/topup.go index c1ac663f..c071b77b 100644 --- a/model/topup.go +++ b/model/topup.go @@ -12,15 +12,16 @@ import ( ) type TopUp struct { - Id int `json:"id"` - UserId int `json:"user_id" gorm:"index"` - Amount int64 `json:"amount"` - Money float64 `json:"money"` - TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"` - PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"` - CreateTime int64 `json:"create_time"` - CompleteTime int64 `json:"complete_time"` - Status string `json:"status"` + Id int `json:"id"` + UserId int `json:"user_id" gorm:"index"` + Amount int64 `json:"amount"` + Money float64 `json:"money"` + TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"` + PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"` + PaymentProvider string `json:"payment_provider" gorm:"type:varchar(50);default:''"` + CreateTime int64 `json:"create_time"` + CompleteTime int64 `json:"complete_time"` + Status string `json:"status"` } const ( @@ -30,6 +31,14 @@ const ( PaymentMethodWaffoPancake = "waffo_pancake" ) +const ( + PaymentProviderEpay = "epay" + PaymentProviderStripe = "stripe" + PaymentProviderCreem = "creem" + PaymentProviderWaffo = "waffo" + PaymentProviderWaffoPancake = "waffo_pancake" +) + var ( ErrPaymentMethodMismatch = errors.New("payment method mismatch") ErrTopUpNotFound = errors.New("topup not found") @@ -68,7 +77,7 @@ func GetTopUpByTradeNo(tradeNo string) *TopUp { return topUp } -func UpdatePendingTopUpStatus(tradeNo string, expectedPaymentMethod string, targetStatus string) error { +func UpdatePendingTopUpStatus(tradeNo string, expectedPaymentProvider string, targetStatus string) error { if tradeNo == "" { return errors.New("未提供支付单号") } @@ -83,7 +92,7 @@ func UpdatePendingTopUpStatus(tradeNo string, expectedPaymentMethod string, targ if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(topUp).Error; err != nil { return ErrTopUpNotFound } - if expectedPaymentMethod != "" && topUp.PaymentMethod != expectedPaymentMethod { + if expectedPaymentProvider != "" && topUp.PaymentProvider != expectedPaymentProvider { return ErrPaymentMethodMismatch } if topUp.Status != common.TopUpStatusPending { @@ -114,7 +123,7 @@ func Recharge(referenceId string, customerId string, callerIp string) (err error return errors.New("充值订单不存在") } - if topUp.PaymentMethod != PaymentMethodStripe { + if topUp.PaymentProvider != PaymentProviderStripe { return ErrPaymentMethodMismatch } @@ -340,7 +349,7 @@ func ManualCompleteTopUp(tradeNo string, callerIp string) error { // 计算应充值额度: // - Stripe 订单:Money 代表经分组倍率换算后的美元数量,直接 * QuotaPerUnit // - 其他订单(如易支付):Amount 为美元数量,* QuotaPerUnit - if topUp.PaymentMethod == PaymentMethodStripe { + if topUp.PaymentProvider == PaymentProviderStripe { dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit) quotaToAdd = int(decimal.NewFromFloat(topUp.Money).Mul(dQuotaPerUnit).IntPart()) } else { @@ -397,7 +406,7 @@ func RechargeCreem(referenceId string, customerEmail string, customerName string return errors.New("充值订单不存在") } - if topUp.PaymentMethod != PaymentMethodCreem { + if topUp.PaymentProvider != PaymentProviderCreem { return ErrPaymentMethodMismatch } @@ -472,7 +481,7 @@ func RechargeWaffo(tradeNo string, callerIp string) (err error) { return errors.New("充值订单不存在") } - if topUp.PaymentMethod != PaymentMethodWaffo { + if topUp.PaymentProvider != PaymentProviderWaffo { return ErrPaymentMethodMismatch } @@ -535,7 +544,7 @@ func RechargeWaffoPancake(tradeNo string) (err error) { return errors.New("充值订单不存在") } - if topUp.PaymentMethod != PaymentMethodWaffoPancake { + if topUp.PaymentProvider != PaymentProviderWaffoPancake { return ErrPaymentMethodMismatch }