723 lines
20 KiB
Go
723 lines
20 KiB
Go
package model
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/QuantumNous/new-api/common"
|
|
"github.com/QuantumNous/new-api/constant"
|
|
|
|
"github.com/glebarez/sqlite"
|
|
"gorm.io/driver/mysql"
|
|
"gorm.io/driver/postgres"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
var commonGroupCol string
|
|
var commonKeyCol string
|
|
var commonTrueVal string
|
|
var commonFalseVal string
|
|
|
|
var logKeyCol string
|
|
var logGroupCol string
|
|
|
|
func initCol() {
|
|
// init common column names
|
|
if common.UsingPostgreSQL {
|
|
commonGroupCol = `"group"`
|
|
commonKeyCol = `"key"`
|
|
commonTrueVal = "true"
|
|
commonFalseVal = "false"
|
|
} else {
|
|
commonGroupCol = "`group`"
|
|
commonKeyCol = "`key`"
|
|
commonTrueVal = "1"
|
|
commonFalseVal = "0"
|
|
}
|
|
if os.Getenv("LOG_SQL_DSN") != "" {
|
|
switch common.LogSqlType {
|
|
case common.DatabaseTypePostgreSQL:
|
|
logGroupCol = `"group"`
|
|
logKeyCol = `"key"`
|
|
default:
|
|
logGroupCol = commonGroupCol
|
|
logKeyCol = commonKeyCol
|
|
}
|
|
} else {
|
|
// LOG_SQL_DSN 为空时,日志数据库与主数据库相同
|
|
if common.UsingPostgreSQL {
|
|
logGroupCol = `"group"`
|
|
logKeyCol = `"key"`
|
|
} else {
|
|
logGroupCol = commonGroupCol
|
|
logKeyCol = commonKeyCol
|
|
}
|
|
}
|
|
// log sql type and database type
|
|
//common.SysLog("Using Log SQL Type: " + common.LogSqlType)
|
|
}
|
|
|
|
var DB *gorm.DB
|
|
|
|
var LOG_DB *gorm.DB
|
|
|
|
func createRootAccountIfNeed() error {
|
|
var user User
|
|
//if user.Status != common.UserStatusEnabled {
|
|
if err := DB.First(&user).Error; err != nil {
|
|
common.SysLog("no user exists, create a root user for you: username is root, password is 123456")
|
|
hashedPassword, err := common.Password2Hash("123456")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
rootUser := User{
|
|
Username: "root",
|
|
Password: hashedPassword,
|
|
Role: common.RoleRootUser,
|
|
Status: common.UserStatusEnabled,
|
|
DisplayName: "Root User",
|
|
AccessToken: nil,
|
|
Quota: 100000000,
|
|
}
|
|
DB.Create(&rootUser)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func CheckSetup() {
|
|
setup := GetSetup()
|
|
if setup == nil {
|
|
// No setup record exists, check if we have a root user
|
|
if RootUserExists() {
|
|
common.SysLog("system is not initialized, but root user exists")
|
|
// Create setup record
|
|
newSetup := Setup{
|
|
Version: common.Version,
|
|
InitializedAt: time.Now().Unix(),
|
|
}
|
|
err := DB.Create(&newSetup).Error
|
|
if err != nil {
|
|
common.SysLog("failed to create setup record: " + err.Error())
|
|
}
|
|
constant.Setup = true
|
|
} else {
|
|
common.SysLog("system is not initialized and no root user exists")
|
|
constant.Setup = false
|
|
}
|
|
} else {
|
|
// Setup record exists, system is initialized
|
|
common.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String())
|
|
constant.Setup = true
|
|
}
|
|
}
|
|
|
|
func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
|
|
defer func() {
|
|
initCol()
|
|
}()
|
|
dsn := os.Getenv(envName)
|
|
if dsn != "" {
|
|
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
|
|
// Use PostgreSQL
|
|
common.SysLog("using PostgreSQL as database")
|
|
if !isLog {
|
|
common.UsingPostgreSQL = true
|
|
} else {
|
|
common.LogSqlType = common.DatabaseTypePostgreSQL
|
|
}
|
|
return gorm.Open(postgres.New(postgres.Config{
|
|
DSN: dsn,
|
|
PreferSimpleProtocol: true, // disables implicit prepared statement usage
|
|
}), &gorm.Config{
|
|
PrepareStmt: true, // precompile SQL
|
|
})
|
|
}
|
|
if strings.HasPrefix(dsn, "local") {
|
|
common.SysLog("SQL_DSN not set, using SQLite as database")
|
|
if !isLog {
|
|
common.UsingSQLite = true
|
|
} else {
|
|
common.LogSqlType = common.DatabaseTypeSQLite
|
|
}
|
|
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
|
|
PrepareStmt: true, // precompile SQL
|
|
})
|
|
}
|
|
// Use MySQL
|
|
common.SysLog("using MySQL as database")
|
|
// check parseTime
|
|
if !strings.Contains(dsn, "parseTime") {
|
|
if strings.Contains(dsn, "?") {
|
|
dsn += "&parseTime=true"
|
|
} else {
|
|
dsn += "?parseTime=true"
|
|
}
|
|
}
|
|
if !isLog {
|
|
common.UsingMySQL = true
|
|
} else {
|
|
common.LogSqlType = common.DatabaseTypeMySQL
|
|
}
|
|
return gorm.Open(mysql.Open(dsn), &gorm.Config{
|
|
PrepareStmt: true, // precompile SQL
|
|
})
|
|
}
|
|
// Use SQLite
|
|
common.SysLog("SQL_DSN not set, using SQLite as database")
|
|
common.UsingSQLite = true
|
|
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
|
|
PrepareStmt: true, // precompile SQL
|
|
})
|
|
}
|
|
|
|
func InitDB() (err error) {
|
|
db, err := chooseDB("SQL_DSN", false)
|
|
if err == nil {
|
|
if common.DebugEnabled {
|
|
db = db.Debug()
|
|
}
|
|
DB = db
|
|
// MySQL charset/collation startup check: ensure Chinese-capable charset
|
|
if common.UsingMySQL {
|
|
if err := checkMySQLChineseSupport(DB); err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
sqlDB, err := DB.DB()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100))
|
|
sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000))
|
|
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60)))
|
|
|
|
if !common.IsMasterNode {
|
|
return nil
|
|
}
|
|
if common.UsingMySQL {
|
|
//_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
|
|
}
|
|
common.SysLog("database migration started")
|
|
err = migrateDB()
|
|
return err
|
|
} else {
|
|
common.FatalLog(err)
|
|
}
|
|
return err
|
|
}
|
|
|
|
func InitLogDB() (err error) {
|
|
if os.Getenv("LOG_SQL_DSN") == "" {
|
|
LOG_DB = DB
|
|
return
|
|
}
|
|
db, err := chooseDB("LOG_SQL_DSN", true)
|
|
if err == nil {
|
|
if common.DebugEnabled {
|
|
db = db.Debug()
|
|
}
|
|
LOG_DB = db
|
|
// If log DB is MySQL, also ensure Chinese-capable charset
|
|
if common.LogSqlType == common.DatabaseTypeMySQL {
|
|
if err := checkMySQLChineseSupport(LOG_DB); err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
sqlDB, err := LOG_DB.DB()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100))
|
|
sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000))
|
|
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60)))
|
|
|
|
if !common.IsMasterNode {
|
|
return nil
|
|
}
|
|
common.SysLog("database migration started")
|
|
err = migrateLOGDB()
|
|
return err
|
|
} else {
|
|
common.FatalLog(err)
|
|
}
|
|
return err
|
|
}
|
|
|
|
func migrateDB() error {
|
|
// Migrate price_amount column from float/double to decimal for existing tables
|
|
migrateSubscriptionPlanPriceAmount()
|
|
// Migrate model_limits column from varchar to text for existing tables
|
|
if err := migrateTokenModelLimitsToText(); err != nil {
|
|
return err
|
|
}
|
|
|
|
err := DB.AutoMigrate(
|
|
&Channel{},
|
|
&Token{},
|
|
&User{},
|
|
&PasskeyCredential{},
|
|
&Option{},
|
|
&Redemption{},
|
|
&Ability{},
|
|
&Log{},
|
|
&Midjourney{},
|
|
&TopUp{},
|
|
&QuotaData{},
|
|
&Task{},
|
|
&Model{},
|
|
&Vendor{},
|
|
&PrefillGroup{},
|
|
&Setup{},
|
|
&TwoFA{},
|
|
&TwoFABackupCode{},
|
|
&Checkin{},
|
|
&SubscriptionOrder{},
|
|
&UserSubscription{},
|
|
&SubscriptionPreConsumeRecord{},
|
|
&CustomOAuthProvider{},
|
|
&UserOAuthBinding{},
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := migrateRedemptionRedeemType(); err != nil {
|
|
return err
|
|
}
|
|
if common.UsingSQLite {
|
|
if err := ensureSubscriptionPlanTableSQLite(); err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
if err := DB.AutoMigrate(&SubscriptionPlan{}); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func migrateDBFast() error {
|
|
|
|
var wg sync.WaitGroup
|
|
|
|
migrations := []struct {
|
|
model interface{}
|
|
name string
|
|
}{
|
|
{&Channel{}, "Channel"},
|
|
{&Token{}, "Token"},
|
|
{&User{}, "User"},
|
|
{&PasskeyCredential{}, "PasskeyCredential"},
|
|
{&Option{}, "Option"},
|
|
{&Redemption{}, "Redemption"},
|
|
{&Ability{}, "Ability"},
|
|
{&Log{}, "Log"},
|
|
{&Midjourney{}, "Midjourney"},
|
|
{&TopUp{}, "TopUp"},
|
|
{&QuotaData{}, "QuotaData"},
|
|
{&Task{}, "Task"},
|
|
{&Model{}, "Model"},
|
|
{&Vendor{}, "Vendor"},
|
|
{&PrefillGroup{}, "PrefillGroup"},
|
|
{&Setup{}, "Setup"},
|
|
{&TwoFA{}, "TwoFA"},
|
|
{&TwoFABackupCode{}, "TwoFABackupCode"},
|
|
{&Checkin{}, "Checkin"},
|
|
{&SubscriptionOrder{}, "SubscriptionOrder"},
|
|
{&UserSubscription{}, "UserSubscription"},
|
|
{&SubscriptionPreConsumeRecord{}, "SubscriptionPreConsumeRecord"},
|
|
{&CustomOAuthProvider{}, "CustomOAuthProvider"},
|
|
{&UserOAuthBinding{}, "UserOAuthBinding"},
|
|
}
|
|
// 动态计算migration数量,确保errChan缓冲区足够大
|
|
errChan := make(chan error, len(migrations))
|
|
|
|
for _, m := range migrations {
|
|
wg.Add(1)
|
|
go func(model interface{}, name string) {
|
|
defer wg.Done()
|
|
if err := DB.AutoMigrate(model); err != nil {
|
|
errChan <- fmt.Errorf("failed to migrate %s: %v", name, err)
|
|
}
|
|
}(m.model, m.name)
|
|
}
|
|
|
|
// Wait for all migrations to complete
|
|
wg.Wait()
|
|
close(errChan)
|
|
|
|
// Check for any errors
|
|
for err := range errChan {
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if err := migrateRedemptionRedeemType(); err != nil {
|
|
return err
|
|
}
|
|
if common.UsingSQLite {
|
|
if err := ensureSubscriptionPlanTableSQLite(); err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
if err := DB.AutoMigrate(&SubscriptionPlan{}); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
common.SysLog("database migrated")
|
|
return nil
|
|
}
|
|
|
|
func migrateRedemptionRedeemType() error {
|
|
if !DB.Migrator().HasTable(&Redemption{}) {
|
|
return nil
|
|
}
|
|
if !DB.Migrator().HasColumn(&Redemption{}, "redeem_type") {
|
|
return nil
|
|
}
|
|
return DB.Model(&Redemption{}).
|
|
Where("redeem_type = '' OR redeem_type IS NULL").
|
|
Update("redeem_type", RedemptionTypeQuota).Error
|
|
}
|
|
|
|
func migrateLOGDB() error {
|
|
var err error
|
|
if err = LOG_DB.AutoMigrate(&Log{}); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type sqliteColumnDef struct {
|
|
Name string
|
|
DDL string
|
|
}
|
|
|
|
func ensureSubscriptionPlanTableSQLite() error {
|
|
if !common.UsingSQLite {
|
|
return nil
|
|
}
|
|
tableName := "subscription_plans"
|
|
if !DB.Migrator().HasTable(tableName) {
|
|
createSQL := `CREATE TABLE ` + "`" + tableName + "`" + ` (
|
|
` + "`id`" + ` integer,
|
|
` + "`title`" + ` varchar(128) NOT NULL,
|
|
` + "`subtitle`" + ` varchar(255) DEFAULT '',
|
|
` + "`price_amount`" + ` decimal(10,6) NOT NULL,
|
|
` + "`currency`" + ` varchar(8) NOT NULL DEFAULT 'USD',
|
|
` + "`duration_unit`" + ` varchar(16) NOT NULL DEFAULT 'month',
|
|
` + "`duration_value`" + ` integer NOT NULL DEFAULT 1,
|
|
` + "`custom_seconds`" + ` bigint NOT NULL DEFAULT 0,
|
|
` + "`enabled`" + ` numeric DEFAULT 1,
|
|
` + "`sort_order`" + ` integer DEFAULT 0,
|
|
` + "`stripe_price_id`" + ` varchar(128) DEFAULT '',
|
|
` + "`creem_product_id`" + ` varchar(128) DEFAULT '',
|
|
` + "`max_purchase_per_user`" + ` integer DEFAULT 0,
|
|
` + "`upgrade_group`" + ` varchar(64) DEFAULT '',
|
|
` + "`total_amount`" + ` bigint NOT NULL DEFAULT 0,
|
|
` + "`quota_reset_period`" + ` varchar(16) DEFAULT 'never',
|
|
` + "`quota_reset_custom_seconds`" + ` bigint DEFAULT 0,
|
|
` + "`created_at`" + ` bigint,
|
|
` + "`updated_at`" + ` bigint,
|
|
PRIMARY KEY (` + "`id`" + `)
|
|
)`
|
|
return DB.Exec(createSQL).Error
|
|
}
|
|
var cols []struct {
|
|
Name string `gorm:"column:name"`
|
|
}
|
|
if err := DB.Raw("PRAGMA table_info(`" + tableName + "`)").Scan(&cols).Error; err != nil {
|
|
return err
|
|
}
|
|
existing := make(map[string]struct{}, len(cols))
|
|
for _, c := range cols {
|
|
existing[c.Name] = struct{}{}
|
|
}
|
|
required := []sqliteColumnDef{
|
|
{Name: "title", DDL: "`title` varchar(128) NOT NULL"},
|
|
{Name: "subtitle", DDL: "`subtitle` varchar(255) DEFAULT ''"},
|
|
{Name: "price_amount", DDL: "`price_amount` decimal(10,6) NOT NULL"},
|
|
{Name: "currency", DDL: "`currency` varchar(8) NOT NULL DEFAULT 'USD'"},
|
|
{Name: "duration_unit", DDL: "`duration_unit` varchar(16) NOT NULL DEFAULT 'month'"},
|
|
{Name: "duration_value", DDL: "`duration_value` integer NOT NULL DEFAULT 1"},
|
|
{Name: "custom_seconds", DDL: "`custom_seconds` bigint NOT NULL DEFAULT 0"},
|
|
{Name: "enabled", DDL: "`enabled` numeric DEFAULT 1"},
|
|
{Name: "sort_order", DDL: "`sort_order` integer DEFAULT 0"},
|
|
{Name: "stripe_price_id", DDL: "`stripe_price_id` varchar(128) DEFAULT ''"},
|
|
{Name: "creem_product_id", DDL: "`creem_product_id` varchar(128) DEFAULT ''"},
|
|
{Name: "max_purchase_per_user", DDL: "`max_purchase_per_user` integer DEFAULT 0"},
|
|
{Name: "upgrade_group", DDL: "`upgrade_group` varchar(64) DEFAULT ''"},
|
|
{Name: "total_amount", DDL: "`total_amount` bigint NOT NULL DEFAULT 0"},
|
|
{Name: "quota_reset_period", DDL: "`quota_reset_period` varchar(16) DEFAULT 'never'"},
|
|
{Name: "quota_reset_custom_seconds", DDL: "`quota_reset_custom_seconds` bigint DEFAULT 0"},
|
|
{Name: "created_at", DDL: "`created_at` bigint"},
|
|
{Name: "updated_at", DDL: "`updated_at` bigint"},
|
|
}
|
|
for _, col := range required {
|
|
if _, ok := existing[col.Name]; ok {
|
|
continue
|
|
}
|
|
if err := DB.Exec("ALTER TABLE `" + tableName + "` ADD COLUMN " + col.DDL).Error; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// migrateTokenModelLimitsToText migrates model_limits column from varchar(1024) to text
|
|
// This is safe to run multiple times - it checks the column type first
|
|
func migrateTokenModelLimitsToText() error {
|
|
// SQLite uses type affinity, so TEXT and VARCHAR are effectively the same — no migration needed
|
|
if common.UsingSQLite {
|
|
return nil
|
|
}
|
|
|
|
tableName := "tokens"
|
|
columnName := "model_limits"
|
|
|
|
if !DB.Migrator().HasTable(tableName) {
|
|
return nil
|
|
}
|
|
|
|
if !DB.Migrator().HasColumn(&Token{}, columnName) {
|
|
return nil
|
|
}
|
|
|
|
var alterSQL string
|
|
if common.UsingPostgreSQL {
|
|
var dataType string
|
|
if err := DB.Raw(`SELECT data_type FROM information_schema.columns
|
|
WHERE table_schema = current_schema() AND table_name = ? AND column_name = ?`,
|
|
tableName, columnName).Scan(&dataType).Error; err != nil {
|
|
common.SysLog(fmt.Sprintf("Warning: failed to query metadata for %s.%s: %v", tableName, columnName, err))
|
|
} else if dataType == "text" {
|
|
return nil
|
|
}
|
|
alterSQL = fmt.Sprintf(`ALTER TABLE %s ALTER COLUMN %s TYPE text`, tableName, columnName)
|
|
} else if common.UsingMySQL {
|
|
var columnType string
|
|
if err := DB.Raw(`SELECT COLUMN_TYPE FROM information_schema.columns
|
|
WHERE table_schema = DATABASE() AND table_name = ? AND column_name = ?`,
|
|
tableName, columnName).Scan(&columnType).Error; err != nil {
|
|
common.SysLog(fmt.Sprintf("Warning: failed to query metadata for %s.%s: %v", tableName, columnName, err))
|
|
} else if strings.ToLower(columnType) == "text" {
|
|
return nil
|
|
}
|
|
alterSQL = fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s text", tableName, columnName)
|
|
} else {
|
|
return nil
|
|
}
|
|
|
|
if alterSQL != "" {
|
|
if err := DB.Exec(alterSQL).Error; err != nil {
|
|
return fmt.Errorf("failed to migrate %s.%s to text: %w", tableName, columnName, err)
|
|
}
|
|
common.SysLog(fmt.Sprintf("Successfully migrated %s.%s to text", tableName, columnName))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// migrateSubscriptionPlanPriceAmount migrates price_amount column from float/double to decimal(10,6)
|
|
// This is safe to run multiple times - it checks the column type first
|
|
func migrateSubscriptionPlanPriceAmount() {
|
|
// SQLite doesn't support ALTER COLUMN, and its type affinity handles this automatically
|
|
// Skip early to avoid GORM parsing the existing table DDL which may cause issues
|
|
if common.UsingSQLite {
|
|
return
|
|
}
|
|
|
|
tableName := "subscription_plans"
|
|
columnName := "price_amount"
|
|
|
|
// Check if table exists first
|
|
if !DB.Migrator().HasTable(tableName) {
|
|
return
|
|
}
|
|
|
|
// Check if column exists
|
|
if !DB.Migrator().HasColumn(&SubscriptionPlan{}, columnName) {
|
|
return
|
|
}
|
|
|
|
var alterSQL string
|
|
if common.UsingPostgreSQL {
|
|
// PostgreSQL: Check if already decimal/numeric
|
|
var dataType string
|
|
if err := DB.Raw(`SELECT data_type FROM information_schema.columns
|
|
WHERE table_schema = current_schema() AND table_name = ? AND column_name = ?`,
|
|
tableName, columnName).Scan(&dataType).Error; err != nil {
|
|
common.SysLog(fmt.Sprintf("Warning: failed to query metadata for %s.%s: %v", tableName, columnName, err))
|
|
} else if dataType == "numeric" {
|
|
return // Already decimal/numeric
|
|
}
|
|
alterSQL = fmt.Sprintf(`ALTER TABLE %s ALTER COLUMN %s TYPE decimal(10,6) USING %s::decimal(10,6)`,
|
|
tableName, columnName, columnName)
|
|
} else if common.UsingMySQL {
|
|
// MySQL: Check if already decimal
|
|
var columnType string
|
|
if err := DB.Raw(`SELECT COLUMN_TYPE FROM information_schema.columns
|
|
WHERE table_schema = DATABASE() AND table_name = ? AND column_name = ?`,
|
|
tableName, columnName).Scan(&columnType).Error; err != nil {
|
|
common.SysLog(fmt.Sprintf("Warning: failed to query metadata for %s.%s: %v", tableName, columnName, err))
|
|
} else if strings.HasPrefix(strings.ToLower(columnType), "decimal") {
|
|
return // Already decimal
|
|
}
|
|
alterSQL = fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s decimal(10,6) NOT NULL DEFAULT 0",
|
|
tableName, columnName)
|
|
} else {
|
|
return
|
|
}
|
|
|
|
if alterSQL != "" {
|
|
if err := DB.Exec(alterSQL).Error; err != nil {
|
|
common.SysLog(fmt.Sprintf("Warning: failed to migrate %s.%s to decimal: %v", tableName, columnName, err))
|
|
} else {
|
|
common.SysLog(fmt.Sprintf("Successfully migrated %s.%s to decimal(10,6)", tableName, columnName))
|
|
}
|
|
}
|
|
}
|
|
|
|
func closeDB(db *gorm.DB) error {
|
|
sqlDB, err := db.DB()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = sqlDB.Close()
|
|
return err
|
|
}
|
|
|
|
func CloseDB() error {
|
|
if LOG_DB != DB {
|
|
err := closeDB(LOG_DB)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return closeDB(DB)
|
|
}
|
|
|
|
// checkMySQLChineseSupport ensures the MySQL connection and current schema
|
|
// default charset/collation can store Chinese characters. It allows common
|
|
// Chinese-capable charsets (utf8mb4, utf8, gbk, big5, gb18030) and panics otherwise.
|
|
func checkMySQLChineseSupport(db *gorm.DB) error {
|
|
// 仅检测:当前库默认字符集/排序规则 + 各表的排序规则(隐含字符集)
|
|
|
|
// Read current schema defaults
|
|
var schemaCharset, schemaCollation string
|
|
err := db.Raw("SELECT DEFAULT_CHARACTER_SET_NAME, DEFAULT_COLLATION_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = DATABASE()").Row().Scan(&schemaCharset, &schemaCollation)
|
|
if err != nil {
|
|
return fmt.Errorf("读取当前库默认字符集/排序规则失败 / Failed to read schema default charset/collation: %v", err)
|
|
}
|
|
|
|
toLower := func(s string) string { return strings.ToLower(s) }
|
|
// Allowed charsets that can store Chinese text
|
|
allowedCharsets := map[string]string{
|
|
"utf8mb4": "utf8mb4_",
|
|
"utf8": "utf8_",
|
|
"gbk": "gbk_",
|
|
"big5": "big5_",
|
|
"gb18030": "gb18030_",
|
|
}
|
|
isChineseCapable := func(cs, cl string) bool {
|
|
csLower := toLower(cs)
|
|
clLower := toLower(cl)
|
|
if prefix, ok := allowedCharsets[csLower]; ok {
|
|
if clLower == "" {
|
|
return true
|
|
}
|
|
return strings.HasPrefix(clLower, prefix)
|
|
}
|
|
// 如果仅提供了排序规则,尝试按排序规则前缀判断
|
|
for _, prefix := range allowedCharsets {
|
|
if strings.HasPrefix(clLower, prefix) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// 1) 当前库默认值必须支持中文
|
|
if !isChineseCapable(schemaCharset, schemaCollation) {
|
|
return fmt.Errorf("当前库默认字符集/排序规则不支持中文:schema(%s/%s)。请将库设置为 utf8mb4/utf8/gbk/big5/gb18030 / Schema default charset/collation is not Chinese-capable: schema(%s/%s). Please set to utf8mb4/utf8/gbk/big5/gb18030",
|
|
schemaCharset, schemaCollation, schemaCharset, schemaCollation)
|
|
}
|
|
|
|
// 2) 所有物理表的排序规则(隐含字符集)必须支持中文
|
|
type tableInfo struct {
|
|
Name string
|
|
Collation *string
|
|
}
|
|
var tables []tableInfo
|
|
if err := db.Raw("SELECT TABLE_NAME, TABLE_COLLATION FROM information_schema.TABLES WHERE TABLE_SCHEMA = DATABASE() AND TABLE_TYPE = 'BASE TABLE'").Scan(&tables).Error; err != nil {
|
|
return fmt.Errorf("读取表排序规则失败 / Failed to read table collations: %v", err)
|
|
}
|
|
|
|
var badTables []string
|
|
for _, t := range tables {
|
|
// NULL 或空表示继承库默认设置,已在上面校验库默认,视为通过
|
|
if t.Collation == nil || *t.Collation == "" {
|
|
continue
|
|
}
|
|
cl := *t.Collation
|
|
// 仅凭排序规则判断是否中文可用
|
|
ok := false
|
|
lower := strings.ToLower(cl)
|
|
for _, prefix := range allowedCharsets {
|
|
if strings.HasPrefix(lower, prefix) {
|
|
ok = true
|
|
break
|
|
}
|
|
}
|
|
if !ok {
|
|
badTables = append(badTables, fmt.Sprintf("%s(%s)", t.Name, cl))
|
|
}
|
|
}
|
|
|
|
if len(badTables) > 0 {
|
|
// 限制输出数量以避免日志过长
|
|
maxShow := 20
|
|
shown := badTables
|
|
if len(shown) > maxShow {
|
|
shown = shown[:maxShow]
|
|
}
|
|
return fmt.Errorf(
|
|
"存在不支持中文的表,请修复其排序规则/字符集。示例(最多展示 %d 项):%v / Found tables not Chinese-capable. Please fix their collation/charset. Examples (showing up to %d): %v",
|
|
maxShow, shown, maxShow, shown,
|
|
)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
var (
|
|
lastPingTime time.Time
|
|
pingMutex sync.Mutex
|
|
)
|
|
|
|
func PingDB() error {
|
|
pingMutex.Lock()
|
|
defer pingMutex.Unlock()
|
|
|
|
if time.Since(lastPingTime) < time.Second*10 {
|
|
return nil
|
|
}
|
|
|
|
sqlDB, err := DB.DB()
|
|
if err != nil {
|
|
log.Printf("Error getting sql.DB from GORM: %v", err)
|
|
return err
|
|
}
|
|
|
|
err = sqlDB.Ping()
|
|
if err != nil {
|
|
log.Printf("Error pinging DB: %v", err)
|
|
return err
|
|
}
|
|
|
|
lastPingTime = time.Now()
|
|
common.SysLog("Database pinged successfully")
|
|
return nil
|
|
}
|