Files
new-api/controller/topup_stripe.go
T
2026-04-18 14:22:54 +08:00

425 lines
16 KiB
Go

package controller
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-gonic/gin"
"github.com/stripe/stripe-go/v81"
"github.com/stripe/stripe-go/v81/checkout/session"
"github.com/stripe/stripe-go/v81/webhook"
"github.com/thanhpk/randstr"
)
// stripeAdaptor is the package-level StripeAdaptor instance that the
// RequestStripeAmount and RequestStripePay handlers delegate to.
var stripeAdaptor = new(StripeAdaptor)
// StripePayRequest is the JSON body accepted by the Stripe top-up endpoints
// (price preview and checkout creation).
type StripePayRequest struct {
	// Amount is the number of units being purchased; it is forwarded to
	// Stripe as the line-item quantity.
	Amount int64 `json:"amount"`

	// PaymentMethod names the payment channel. RequestPay only accepts the
	// value of model.PaymentMethodStripe.
	PaymentMethod string `json:"payment_method"`

	// SuccessURL optionally overrides where Stripe redirects after a
	// successful payment. When empty, the server's "/console/log" page is used.
	SuccessURL string `json:"success_url,omitempty"`

	// CancelURL optionally overrides where Stripe redirects when the payment
	// is canceled. When empty, the server's "/console/topup" page is used.
	CancelURL string `json:"cancel_url,omitempty"`
}
// StripeAdaptor groups the Stripe top-up request handlers. It has no fields;
// its methods work only from the gin context and the decoded request.
type StripeAdaptor struct{}
// RequestAmount answers a price preview for a Stripe top-up: it checks the
// requested amount against the configured minimum, looks up the caller's
// group, and replies with the payable money formatted to two decimals.
//
// Every outcome is sent as HTTP 200 with a {"message", "data"} JSON object;
// failures use message "error" and put the reason in "data".
func (*StripeAdaptor) RequestAmount(c *gin.Context, req *StripePayRequest) {
	fail := func(reason string) {
		c.JSON(http.StatusOK, gin.H{"message": "error", "data": reason})
	}

	if minTopup := getStripeMinTopup(); req.Amount < minTopup {
		fail(fmt.Sprintf("充值数量不能小于 %d", minTopup))
		return
	}

	userGroup, err := model.GetUserGroup(c.GetInt("id"), true)
	if err != nil {
		fail("获取用户分组失败")
		return
	}

	price := getStripePayMoney(float64(req.Amount), userGroup)
	if price <= 0.01 {
		fail("充值金额过低")
		return
	}
	c.JSON(http.StatusOK, gin.H{"message": "success", "data": strconv.FormatFloat(price, 'f', 2, 64)})
}
// RequestPay validates a Stripe top-up request, creates a Stripe Checkout
// session for it, stores a pending top-up order under the generated trade
// number, and replies with the checkout link.
//
// Validation, in order:
//   - the payment method must be model.PaymentMethodStripe;
//   - the amount must be at least getStripeMinTopup() and at most 10000;
//   - non-empty success/cancel URLs must pass common.ValidateRedirectURL
//     (these two failures answer HTTP 400; all other replies are HTTP 200).
//
// The Checkout session is created before the local order is inserted, so an
// insert failure is reported to the caller without returning the link.
func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) {
	if req.PaymentMethod != model.PaymentMethodStripe {
		c.JSON(http.StatusOK, gin.H{"message": "error", "data": "不支持的支付渠道"})
		return
	}
	if req.Amount < getStripeMinTopup() {
		c.JSON(http.StatusOK, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup()), "data": 10})
		return
	}
	if req.Amount > 10000 {
		c.JSON(http.StatusOK, gin.H{"message": "充值数量不能大于 10000", "data": 10})
		return
	}

	if req.SuccessURL != "" && common.ValidateRedirectURL(req.SuccessURL) != nil {
		c.JSON(http.StatusBadRequest, gin.H{"message": "支付成功重定向URL不在可信任域名列表中", "data": ""})
		return
	}
	if req.CancelURL != "" && common.ValidateRedirectURL(req.CancelURL) != nil {
		c.JSON(http.StatusBadRequest, gin.H{"message": "支付取消重定向URL不在可信任域名列表中", "data": ""})
		return
	}

	id := c.GetInt("id")
	// The user record is dereferenced below, so a failed lookup must stop the
	// request here instead of panicking on a nil pointer.
	user, err := model.GetUserById(id, false)
	if err != nil || user == nil {
		reason := "user not found"
		if err != nil {
			reason = err.Error()
		}
		logger.LogError(c.Request.Context(), fmt.Sprintf("Stripe 获取用户信息失败 user_id=%d error=%q", id, reason))
		c.JSON(http.StatusOK, gin.H{"message": "error", "data": "获取用户信息失败"})
		return
	}
	chargedMoney := GetChargedAmount(float64(req.Amount), *user)

	// The trade number is the SHA-1 of a per-request string built from the
	// user id, a millisecond timestamp and a short random suffix.
	reference := fmt.Sprintf("new-api-ref-%d-%d-%s", user.Id, time.Now().UnixMilli(), randstr.String(4))
	referenceId := "ref_" + common.Sha1([]byte(reference))

	payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, req.Amount, req.SuccessURL, req.CancelURL)
	if err != nil {
		logger.LogError(c.Request.Context(), fmt.Sprintf("Stripe 创建 Checkout Session 失败 user_id=%d trade_no=%s amount=%d error=%q", id, referenceId, req.Amount, err.Error()))
		c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
		return
	}

	topUp := &model.TopUp{
		UserId:        id,
		Amount:        req.Amount,
		Money:         chargedMoney,
		TradeNo:       referenceId,
		PaymentMethod: model.PaymentMethodStripe,
		CreateTime:    time.Now().Unix(),
		Status:        common.TopUpStatusPending,
	}
	err = topUp.Insert()
	if err != nil {
		logger.LogError(c.Request.Context(), fmt.Sprintf("Stripe 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, referenceId, req.Amount, err.Error()))
		c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
		return
	}

	logger.LogInfo(c.Request.Context(), fmt.Sprintf("Stripe 充值订单创建成功 user_id=%d trade_no=%s amount=%d money=%.2f", id, referenceId, req.Amount, chargedMoney))
	c.JSON(http.StatusOK, gin.H{
		"message": "success",
		"data": gin.H{
			"pay_link": payLink,
		},
	})
}
// RequestStripeAmount is the gin handler for the Stripe price preview. It
// binds the JSON body into a StripePayRequest and passes it to
// stripeAdaptor.RequestAmount; a body that fails to bind is answered with
// HTTP 200 and an error envelope.
func RequestStripeAmount(c *gin.Context) {
	payReq := new(StripePayRequest)
	if bindErr := c.ShouldBindJSON(payReq); bindErr != nil {
		c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
		return
	}
	stripeAdaptor.RequestAmount(c, payReq)
}
// RequestStripePay is the gin handler that starts a Stripe checkout. It binds
// the JSON body into a StripePayRequest and passes it to
// stripeAdaptor.RequestPay; a body that fails to bind is answered with
// HTTP 200 and an error envelope.
func RequestStripePay(c *gin.Context) {
	payReq := new(StripePayRequest)
	if bindErr := c.ShouldBindJSON(payReq); bindErr != nil {
		c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
		return
	}
	stripeAdaptor.RequestPay(c, payReq)
}
// StripeWebhook receives Stripe webhook deliveries, verifies the
// Stripe-Signature header against setting.StripeWebhookSecret, and dispatches
// the checkout.session.* events this file handles.
//
// Responses:
//   - 403 when isStripeWebhookEnabled() reports false;
//   - 503 when the request body cannot be read (including a body larger
//     than the size cap below);
//   - 400 when signature verification fails;
//   - 200 for every verified event. The per-event handlers return no error,
//     so a processing failure is only logged and is not signalled to Stripe.
func StripeWebhook(c *gin.Context) {
	// The endpoint is unauthenticated until the signature is checked, so the
	// body size is capped before reading. 65536 bytes is the limit used by
	// Stripe's reference Go webhook handler.
	const maxWebhookBodyBytes = int64(65536)

	ctx := c.Request.Context()
	callerIp := c.ClientIP()
	requestURI := c.Request.RequestURI

	if !isStripeWebhookEnabled() {
		logger.LogWarn(ctx, fmt.Sprintf("Stripe webhook 被拒绝 reason=webhook_disabled path=%q client_ip=%s", requestURI, callerIp))
		c.AbortWithStatus(http.StatusForbidden)
		return
	}

	c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxWebhookBodyBytes)
	payload, err := io.ReadAll(c.Request.Body)
	if err != nil {
		logger.LogError(ctx, fmt.Sprintf("Stripe webhook 读取请求体失败 path=%q client_ip=%s error=%q", requestURI, callerIp, err.Error()))
		c.AbortWithStatus(http.StatusServiceUnavailable)
		return
	}

	signature := c.GetHeader("Stripe-Signature")
	// NOTE(review): the raw body and signature are logged before the
	// signature has been verified; confirm this is acceptable for the
	// deployment's log retention.
	logger.LogInfo(ctx, fmt.Sprintf("Stripe webhook 收到请求 path=%q client_ip=%s signature=%q body=%q", requestURI, callerIp, signature, string(payload)))

	event, err := webhook.ConstructEventWithOptions(payload, signature, setting.StripeWebhookSecret, webhook.ConstructEventOptions{
		IgnoreAPIVersionMismatch: true,
	})
	if err != nil {
		logger.LogWarn(ctx, fmt.Sprintf("Stripe webhook 验签失败 path=%q client_ip=%s error=%q", requestURI, callerIp, err.Error()))
		c.AbortWithStatus(http.StatusBadRequest)
		return
	}

	logger.LogInfo(ctx, fmt.Sprintf("Stripe webhook 验签成功 event_type=%s client_ip=%s path=%q", string(event.Type), callerIp, requestURI))

	switch event.Type {
	case stripe.EventTypeCheckoutSessionCompleted:
		sessionCompleted(ctx, event, callerIp)
	case stripe.EventTypeCheckoutSessionExpired:
		sessionExpired(ctx, event)
	case stripe.EventTypeCheckoutSessionAsyncPaymentSucceeded:
		sessionAsyncPaymentSucceeded(ctx, event, callerIp)
	case stripe.EventTypeCheckoutSessionAsyncPaymentFailed:
		sessionAsyncPaymentFailed(ctx, event, callerIp)
	default:
		logger.LogInfo(ctx, fmt.Sprintf("Stripe webhook 忽略事件 event_type=%s client_ip=%s", string(event.Type), callerIp))
	}

	c.Status(http.StatusOK)
}
// sessionCompleted handles a checkout.session.completed event. The order is
// fulfilled only when the session's "status" is "complete" and its
// "payment_status" is "paid"; any other combination is logged and skipped.
func sessionCompleted(ctx context.Context, event stripe.Event, callerIp string) {
	tradeNo := event.GetObjectValue("client_reference_id")

	if sessionStatus := event.GetObjectValue("status"); sessionStatus != "complete" {
		logger.LogWarn(ctx, fmt.Sprintf("Stripe checkout.completed 状态异常,忽略处理 trade_no=%s status=%s client_ip=%s", tradeNo, sessionStatus, callerIp))
		return
	}

	// A completed session that is not yet paid is left for the async payment
	// events to settle.
	if payStatus := event.GetObjectValue("payment_status"); payStatus != "paid" {
		logger.LogInfo(ctx, fmt.Sprintf("Stripe Checkout 支付未完成,等待异步结果 trade_no=%s payment_status=%s client_ip=%s", tradeNo, payStatus, callerIp))
		return
	}

	fulfillOrder(ctx, event, tradeNo, event.GetObjectValue("customer"), callerIp)
}
// sessionAsyncPaymentSucceeded handles
// checkout.session.async_payment_succeeded, which reports that a delayed
// payment method (bank transfer, SEPA, etc.) was confirmed after the checkout
// session completed. It logs the event and fulfills the referenced order.
func sessionAsyncPaymentSucceeded(ctx context.Context, event stripe.Event, callerIp string) {
	tradeNo := event.GetObjectValue("client_reference_id")
	logger.LogInfo(ctx, fmt.Sprintf("Stripe 异步支付成功 trade_no=%s client_ip=%s", tradeNo, callerIp))
	fulfillOrder(ctx, event, tradeNo, event.GetObjectValue("customer"), callerIp)
}
// sessionAsyncPaymentFailed handles checkout.session.async_payment_failed,
// which reports that a delayed payment method ultimately failed (e.g. bank
// transfer not received, SEPA rejected). While holding the per-order lock it
// marks the matching top-up as failed, but only if that order exists, was
// paid through Stripe, and is still pending.
func sessionAsyncPaymentFailed(ctx context.Context, event stripe.Event, callerIp string) {
	tradeNo := event.GetObjectValue("client_reference_id")
	logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败 trade_no=%s client_ip=%s", tradeNo, callerIp))
	if tradeNo == "" {
		logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败事件缺少订单号 client_ip=%s", callerIp))
		return
	}

	LockOrder(tradeNo)
	defer UnlockOrder(tradeNo)

	order := model.GetTopUpByTradeNo(tradeNo)
	// Cases are evaluated top to bottom, so the nil check guards the field
	// accesses in the later cases.
	switch {
	case order == nil:
		logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败但本地订单不存在 trade_no=%s client_ip=%s", tradeNo, callerIp))
		return
	case order.PaymentMethod != model.PaymentMethodStripe:
		logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败但订单支付方式不匹配 trade_no=%s payment_method=%s client_ip=%s", tradeNo, order.PaymentMethod, callerIp))
		return
	case order.Status != common.TopUpStatusPending:
		logger.LogInfo(ctx, fmt.Sprintf("Stripe 异步支付失败但订单状态非 pending,忽略处理 trade_no=%s status=%s client_ip=%s", tradeNo, order.Status, callerIp))
		return
	}

	order.Status = common.TopUpStatusFailed
	if updateErr := order.Update(); updateErr != nil {
		logger.LogError(ctx, fmt.Sprintf("Stripe 标记充值订单失败状态失败 trade_no=%s client_ip=%s error=%q", tradeNo, callerIp, updateErr.Error()))
		return
	}
	logger.LogInfo(ctx, fmt.Sprintf("Stripe 充值订单已标记为失败 trade_no=%s client_ip=%s", tradeNo, callerIp))
}
// fulfillOrder is the shared completion path once Stripe reports a payment
// as confirmed. Under the per-order lock it first tries to complete a
// subscription order with this trade number; only when that reports
// model.ErrSubscriptionOrderNotFound does it fall back to model.Recharge.
// Outcomes are logged; nothing is returned to the caller.
func fulfillOrder(ctx context.Context, event stripe.Event, referenceId string, customerId string, callerIp string) {
	if referenceId == "" {
		logger.LogWarn(ctx, fmt.Sprintf("Stripe 完成订单时缺少订单号 client_ip=%s", callerIp))
		return
	}

	LockOrder(referenceId)
	defer UnlockOrder(referenceId)

	eventType := string(event.Type)
	amountTotal := event.GetObjectValue("amount_total")
	currency := strings.ToUpper(event.GetObjectValue("currency"))

	providerPayload := common.GetJsonString(map[string]any{
		"customer":     customerId,
		"amount_total": amountTotal,
		"currency":     currency,
		"event_type":   eventType,
	})

	subErr := model.CompleteSubscriptionOrder(referenceId, providerPayload, model.PaymentMethodStripe)
	switch {
	case subErr == nil:
		logger.LogInfo(ctx, fmt.Sprintf("Stripe 订阅订单处理成功 trade_no=%s event_type=%s client_ip=%s", referenceId, eventType, callerIp))
		return
	case !errors.Is(subErr, model.ErrSubscriptionOrderNotFound):
		logger.LogError(ctx, fmt.Sprintf("Stripe 订阅订单处理失败 trade_no=%s event_type=%s client_ip=%s error=%q", referenceId, eventType, callerIp, subErr.Error()))
		return
	}

	// No subscription order carries this trade number: treat it as a top-up.
	if rechargeErr := model.Recharge(referenceId, customerId, callerIp); rechargeErr != nil {
		logger.LogError(ctx, fmt.Sprintf("Stripe 充值处理失败 trade_no=%s event_type=%s client_ip=%s error=%q", referenceId, eventType, callerIp, rechargeErr.Error()))
		return
	}

	// The parsed total is used only in the log line below, so a parse failure
	// is ignored and simply logs 0.00.
	total, _ := strconv.ParseFloat(amountTotal, 64)
	logger.LogInfo(ctx, fmt.Sprintf("Stripe 充值成功 trade_no=%s amount_total=%.2f currency=%s event_type=%s client_ip=%s", referenceId, total/100, currency, eventType, callerIp))
}
// sessionExpired handles a checkout.session.expired event. Events whose
// "status" is not "expired", or that carry no client_reference_id, are logged
// and skipped. Otherwise, under the per-order lock, it first tries to expire a
// subscription order with this trade number and, when none exists, marks the
// pending Stripe top-up as expired.
func sessionExpired(ctx context.Context, event stripe.Event) {
	tradeNo := event.GetObjectValue("client_reference_id")

	if sessionStatus := event.GetObjectValue("status"); sessionStatus != "expired" {
		logger.LogWarn(ctx, fmt.Sprintf("Stripe checkout.expired 状态异常,忽略处理 trade_no=%s status=%s", tradeNo, sessionStatus))
		return
	}
	if tradeNo == "" {
		logger.LogWarn(ctx, "Stripe checkout.expired 缺少订单号")
		return
	}

	LockOrder(tradeNo)
	defer UnlockOrder(tradeNo)

	// Subscription orders are tried first.
	subErr := model.ExpireSubscriptionOrder(tradeNo, model.PaymentMethodStripe)
	switch {
	case subErr == nil:
		logger.LogInfo(ctx, fmt.Sprintf("Stripe 订阅订单已过期 trade_no=%s", tradeNo))
		return
	case !errors.Is(subErr, model.ErrSubscriptionOrderNotFound):
		logger.LogError(ctx, fmt.Sprintf("Stripe 订阅订单过期处理失败 trade_no=%s error=%q", tradeNo, subErr.Error()))
		return
	}

	// Not a subscription order: fall back to the top-up order.
	topUpErr := model.UpdatePendingTopUpStatus(tradeNo, model.PaymentMethodStripe, common.TopUpStatusExpired)
	switch {
	case errors.Is(topUpErr, model.ErrTopUpNotFound):
		logger.LogWarn(ctx, fmt.Sprintf("Stripe 充值订单不存在,无法标记过期 trade_no=%s", tradeNo))
	case topUpErr != nil:
		logger.LogError(ctx, fmt.Sprintf("Stripe 充值订单过期处理失败 trade_no=%s error=%q", tradeNo, topUpErr.Error()))
	default:
		logger.LogInfo(ctx, fmt.Sprintf("Stripe 充值订单已过期 trade_no=%s", tradeNo))
	}
}
// genStripeLink creates a payment-mode Stripe Checkout session and returns
// its URL.
//
// Parameters:
//   - referenceId: trade number stored as the session's client reference ID
//   - customerId: existing Stripe customer ID; when empty, Stripe is asked to
//     always create a customer
//   - email: prefilled as the customer email when no customerId is given
//   - amount: line-item quantity for the configured price (setting.StripePriceId)
//   - successURL: redirect after payment; empty means ServerAddress + "/console/log"
//   - cancelURL: redirect on cancel; empty means ServerAddress + "/console/topup"
//
// An error is returned when the configured API secret does not start with
// "sk_" or "rk_", or when Stripe rejects the session creation.
//
// NOTE(review): each call assigns the package-level stripe.Key.
func genStripeLink(referenceId string, customerId string, email string, amount int64, successURL string, cancelURL string) (string, error) {
	apiKey := setting.StripeApiSecret
	if !(strings.HasPrefix(apiKey, "sk_") || strings.HasPrefix(apiKey, "rk_")) {
		return "", fmt.Errorf("无效的Stripe API密钥")
	}
	stripe.Key = apiKey

	// Fall back to the console pages when the caller supplied no redirects.
	if successURL == "" {
		successURL = system_setting.ServerAddress + "/console/log"
	}
	if cancelURL == "" {
		cancelURL = system_setting.ServerAddress + "/console/topup"
	}

	lineItem := &stripe.CheckoutSessionLineItemParams{
		Price:    stripe.String(setting.StripePriceId),
		Quantity: stripe.Int64(amount),
	}
	params := &stripe.CheckoutSessionParams{
		ClientReferenceID:   stripe.String(referenceId),
		SuccessURL:          stripe.String(successURL),
		CancelURL:           stripe.String(cancelURL),
		LineItems:           []*stripe.CheckoutSessionLineItemParams{lineItem},
		Mode:                stripe.String(string(stripe.CheckoutSessionModePayment)),
		AllowPromotionCodes: stripe.Bool(setting.StripePromotionCodesEnabled),
	}

	switch {
	case customerId != "":
		params.Customer = stripe.String(customerId)
	default:
		params.CustomerCreation = stripe.String(string(stripe.CheckoutSessionCustomerCreationAlways))
		if email != "" {
			params.CustomerEmail = stripe.String(email)
		}
	}

	checkout, err := session.New(params)
	if err != nil {
		return "", err
	}
	return checkout.URL, nil
}
// GetChargedAmount scales count by the top-up ratio configured for the user's
// group. A ratio of 0 is treated as 1, so count is returned unchanged.
func GetChargedAmount(count float64, user model.User) float64 {
	ratio := common.GetTopupGroupRatio(user.Group)
	if ratio == 0 {
		return count
	}
	return count * ratio
}
// getStripePayMoney computes the money payable for a requested top-up amount:
// units * setting.StripeUnitPrice * group ratio * discount.
//
// When quotas are displayed as tokens, the amount is first divided by
// common.QuotaPerUnit to obtain units. A group ratio of 0 is treated as 1.
// The discount is looked up in the payment settings by the original
// (unconverted) amount truncated to int, and is applied only when positive.
func getStripePayMoney(amount float64, group string) float64 {
	requested := amount

	units := amount
	if operation_setting.GetQuotaDisplayType() == operation_setting.QuotaDisplayTypeTokens {
		units = amount / common.QuotaPerUnit
	}

	// float64 is considered acceptable for this monetary calculation because
	// the amounts involved are small.
	ratio := common.GetTopupGroupRatio(group)
	if ratio == 0 {
		ratio = 1
	}

	discount := 1.0
	if preset, found := operation_setting.GetPaymentSetting().AmountDiscount[int(requested)]; found && preset > 0 {
		discount = preset
	}

	return units * setting.StripeUnitPrice * ratio * discount
}
func getStripeMinTopup() int64 {
minTopup := setting.StripeMinTopUp
if operation_setting.GetQuotaDisplayType() == operation_setting.QuotaDisplayTypeTokens {
minTopup = minTopup * int(common.QuotaPerUnit)
}
return int64(minTopup)
}