Merge pull request #4401 from XiaoAI1024/codex/legacy-token-key-compat
Relax token key column length for legacy migration compatibility
This commit is contained in:
+271
-5
@@ -2,10 +2,12 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -14,6 +16,8 @@ import (
|
|||||||
"github.com/QuantumNous/new-api/model"
|
"github.com/QuantumNous/new-api/model"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/glebarez/sqlite"
|
"github.com/glebarez/sqlite"
|
||||||
|
"gorm.io/driver/mysql"
|
||||||
|
"gorm.io/driver/postgres"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -38,7 +42,36 @@ type tokenKeyResponse struct {
|
|||||||
Key string `json:"key"`
|
Key string `json:"key"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
|
type sqliteColumnInfo struct {
|
||||||
|
Name string `gorm:"column:name"`
|
||||||
|
Type string `gorm:"column:type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type legacyToken struct {
|
||||||
|
Id int `gorm:"primaryKey"`
|
||||||
|
UserId int `gorm:"index"`
|
||||||
|
Key string `gorm:"column:key;type:char(48);uniqueIndex"`
|
||||||
|
Status int `gorm:"default:1"`
|
||||||
|
Name string `gorm:"index"`
|
||||||
|
CreatedTime int64 `gorm:"bigint"`
|
||||||
|
AccessedTime int64 `gorm:"bigint"`
|
||||||
|
ExpiredTime int64 `gorm:"bigint;default:-1"`
|
||||||
|
RemainQuota int `gorm:"default:0"`
|
||||||
|
UnlimitedQuota bool
|
||||||
|
ModelLimitsEnabled bool
|
||||||
|
ModelLimits string `gorm:"type:text"`
|
||||||
|
AllowIps *string `gorm:"default:''"`
|
||||||
|
UsedQuota int `gorm:"default:0"`
|
||||||
|
Group string `gorm:"column:group;default:''"`
|
||||||
|
CrossGroupRetry bool
|
||||||
|
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (legacyToken) TableName() string {
|
||||||
|
return "tokens"
|
||||||
|
}
|
||||||
|
|
||||||
|
func openTokenControllerTestDB(t *testing.T) *gorm.DB {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
@@ -55,10 +88,6 @@ func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
|
|||||||
model.DB = db
|
model.DB = db
|
||||||
model.LOG_DB = db
|
model.LOG_DB = db
|
||||||
|
|
||||||
if err := db.AutoMigrate(&model.Token{}); err != nil {
|
|
||||||
t.Fatalf("failed to migrate token table: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
sqlDB, err := db.DB()
|
sqlDB, err := db.DB()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -69,6 +98,69 @@ func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
|
|||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func migrateTokenControllerTestDB(t *testing.T, db *gorm.DB) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if err := db.AutoMigrate(&model.Token{}); err != nil {
|
||||||
|
t.Fatalf("failed to migrate token table: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
db := openTokenControllerTestDB(t)
|
||||||
|
migrateTokenControllerTestDB(t, db)
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func openTokenControllerExternalDB(t *testing.T, dialect string, dsn string) (*gorm.DB, *bool) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
common.RedisEnabled = false
|
||||||
|
common.UsingSQLite = false
|
||||||
|
common.UsingMySQL = dialect == "mysql"
|
||||||
|
common.UsingPostgreSQL = dialect == "postgres"
|
||||||
|
|
||||||
|
var (
|
||||||
|
db *gorm.DB
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
switch dialect {
|
||||||
|
case "mysql":
|
||||||
|
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
|
||||||
|
case "postgres":
|
||||||
|
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||||||
|
default:
|
||||||
|
t.Fatalf("unsupported dialect %q", dialect)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to open %s db: %v", dialect, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
model.DB = db
|
||||||
|
model.LOG_DB = db
|
||||||
|
|
||||||
|
if db.Migrator().HasTable("tokens") {
|
||||||
|
t.Skipf("refusing to run %s migration compatibility test against external database because tokens table already exists", dialect)
|
||||||
|
}
|
||||||
|
|
||||||
|
managedTokensTable := new(bool)
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if *managedTokensTable && db.Migrator().HasTable("tokens") {
|
||||||
|
_ = db.Migrator().DropTable("tokens")
|
||||||
|
}
|
||||||
|
sqlDB, err := db.DB()
|
||||||
|
if err == nil {
|
||||||
|
_ = sqlDB.Close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return db, managedTokensTable
|
||||||
|
}
|
||||||
|
|
||||||
func seedToken(t *testing.T, db *gorm.DB, userID int, name string, rawKey string) *model.Token {
|
func seedToken(t *testing.T, db *gorm.DB, userID int, name string, rawKey string) *model.Token {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
@@ -124,6 +216,180 @@ func decodeAPIResponse(t *testing.T, recorder *httptest.ResponseRecorder) tokenA
|
|||||||
return response
|
return response
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getSQLiteColumnType(t *testing.T, db *gorm.DB, tableName string, columnName string) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var columns []sqliteColumnInfo
|
||||||
|
if err := db.Raw("PRAGMA table_info(" + tableName + ")").Scan(&columns).Error; err != nil {
|
||||||
|
t.Fatalf("failed to inspect %s schema: %v", tableName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, column := range columns {
|
||||||
|
if column.Name == columnName {
|
||||||
|
return strings.ToLower(column.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Fatalf("column %s not found in %s schema", columnName, tableName)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTokenKeyColumnType(t *testing.T, db *gorm.DB, dialect string) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
switch dialect {
|
||||||
|
case "sqlite":
|
||||||
|
return getSQLiteColumnType(t, db, "tokens", "key")
|
||||||
|
case "mysql":
|
||||||
|
var columnType string
|
||||||
|
if err := db.Raw(`SELECT COLUMN_TYPE FROM information_schema.columns
|
||||||
|
WHERE table_schema = DATABASE() AND table_name = ? AND column_name = ?`,
|
||||||
|
"tokens", "key").Scan(&columnType).Error; err != nil {
|
||||||
|
t.Fatalf("failed to inspect mysql token key column: %v", err)
|
||||||
|
}
|
||||||
|
return strings.ToLower(columnType)
|
||||||
|
case "postgres":
|
||||||
|
var dataType string
|
||||||
|
var maxLength sql.NullInt64
|
||||||
|
if err := db.Raw(`SELECT data_type, character_maximum_length
|
||||||
|
FROM information_schema.columns
|
||||||
|
WHERE table_schema = current_schema() AND table_name = ? AND column_name = ?`,
|
||||||
|
"tokens", "key").Row().Scan(&dataType, &maxLength); err != nil {
|
||||||
|
t.Fatalf("failed to inspect postgres token key column: %v", err)
|
||||||
|
}
|
||||||
|
switch strings.ToLower(dataType) {
|
||||||
|
case "character varying":
|
||||||
|
return fmt.Sprintf("varchar(%d)", maxLength.Int64)
|
||||||
|
case "character":
|
||||||
|
return fmt.Sprintf("char(%d)", maxLength.Int64)
|
||||||
|
default:
|
||||||
|
if maxLength.Valid {
|
||||||
|
return fmt.Sprintf("%s(%d)", strings.ToLower(dataType), maxLength.Int64)
|
||||||
|
}
|
||||||
|
return strings.ToLower(dataType)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
t.Fatalf("unsupported dialect %q", dialect)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runTokenMigrationCompatibilityTest(t *testing.T, db *gorm.DB, dialect string, managedTokensTable *bool) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
legacyKey := strings.Repeat("a", 48)
|
||||||
|
longKey := strings.Repeat("b", 64)
|
||||||
|
|
||||||
|
if err := db.AutoMigrate(&legacyToken{}); err != nil {
|
||||||
|
t.Fatalf("failed to create legacy token schema: %v", err)
|
||||||
|
}
|
||||||
|
if managedTokensTable != nil {
|
||||||
|
*managedTokensTable = true
|
||||||
|
}
|
||||||
|
if err := db.Create(&legacyToken{
|
||||||
|
UserId: 7,
|
||||||
|
Key: legacyKey,
|
||||||
|
Status: common.TokenStatusEnabled,
|
||||||
|
Name: "legacy-token",
|
||||||
|
CreatedTime: 1,
|
||||||
|
AccessedTime: 1,
|
||||||
|
ExpiredTime: -1,
|
||||||
|
RemainQuota: 100,
|
||||||
|
UnlimitedQuota: true,
|
||||||
|
ModelLimitsEnabled: false,
|
||||||
|
ModelLimits: "",
|
||||||
|
AllowIps: common.GetPointer(""),
|
||||||
|
UsedQuota: 0,
|
||||||
|
Group: "default",
|
||||||
|
CrossGroupRetry: false,
|
||||||
|
}).Error; err != nil {
|
||||||
|
t.Fatalf("failed to seed legacy token row: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := getTokenKeyColumnType(t, db, dialect); got != "char(48)" {
|
||||||
|
t.Fatalf("expected legacy key column type char(48), got %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
migrateTokenControllerTestDB(t, db)
|
||||||
|
|
||||||
|
if got := getTokenKeyColumnType(t, db, dialect); got != "varchar(128)" {
|
||||||
|
t.Fatalf("expected migrated key column type varchar(128), got %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
var migratedToken model.Token
|
||||||
|
if err := db.First(&migratedToken, "name = ?", "legacy-token").Error; err != nil {
|
||||||
|
t.Fatalf("failed to load migrated token row: %v", err)
|
||||||
|
}
|
||||||
|
if migratedToken.Key != legacyKey {
|
||||||
|
t.Fatalf("expected migrated token key %q, got %q", legacyKey, migratedToken.Key)
|
||||||
|
}
|
||||||
|
if migratedToken.Name != "legacy-token" {
|
||||||
|
t.Fatalf("expected migrated token name to be preserved, got %q", migratedToken.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
inserted := model.Token{
|
||||||
|
UserId: 8,
|
||||||
|
Name: "long-token",
|
||||||
|
Key: longKey,
|
||||||
|
Status: common.TokenStatusEnabled,
|
||||||
|
CreatedTime: 1,
|
||||||
|
AccessedTime: 1,
|
||||||
|
ExpiredTime: -1,
|
||||||
|
RemainQuota: 200,
|
||||||
|
UnlimitedQuota: true,
|
||||||
|
ModelLimitsEnabled: false,
|
||||||
|
ModelLimits: "",
|
||||||
|
AllowIps: common.GetPointer(""),
|
||||||
|
UsedQuota: 0,
|
||||||
|
Group: "default",
|
||||||
|
CrossGroupRetry: false,
|
||||||
|
}
|
||||||
|
if err := db.Create(&inserted).Error; err != nil {
|
||||||
|
t.Fatalf("failed to insert long token after migration: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var fetched model.Token
|
||||||
|
if err := db.First(&fetched, "id = ?", inserted.Id).Error; err != nil {
|
||||||
|
t.Fatalf("failed to fetch long token after migration: %v", err)
|
||||||
|
}
|
||||||
|
if fetched.Key != longKey {
|
||||||
|
t.Fatalf("expected long token key %q, got %q", longKey, fetched.Key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenAutoMigrateUsesVarchar128KeyColumn(t *testing.T) {
|
||||||
|
db := setupTokenControllerTestDB(t)
|
||||||
|
|
||||||
|
if got := getTokenKeyColumnType(t, db, "sqlite"); got != "varchar(128)" {
|
||||||
|
t.Fatalf("expected key column type varchar(128), got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenMigrationFromChar48ToVarchar128(t *testing.T) {
|
||||||
|
db := openTokenControllerTestDB(t)
|
||||||
|
runTokenMigrationCompatibilityTest(t, db, "sqlite", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenMigrationFromChar48ToVarchar128MySQL(t *testing.T) {
|
||||||
|
dsn := os.Getenv("TEST_MYSQL_DSN")
|
||||||
|
if dsn == "" {
|
||||||
|
t.Skip("set TEST_MYSQL_DSN to run mysql migration compatibility test")
|
||||||
|
}
|
||||||
|
|
||||||
|
db, managedTokensTable := openTokenControllerExternalDB(t, "mysql", dsn)
|
||||||
|
runTokenMigrationCompatibilityTest(t, db, "mysql", managedTokensTable)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenMigrationFromChar48ToVarchar128Postgres(t *testing.T) {
|
||||||
|
dsn := os.Getenv("TEST_POSTGRES_DSN")
|
||||||
|
if dsn == "" {
|
||||||
|
t.Skip("set TEST_POSTGRES_DSN to run postgres migration compatibility test")
|
||||||
|
}
|
||||||
|
|
||||||
|
db, managedTokensTable := openTokenControllerExternalDB(t, "postgres", dsn)
|
||||||
|
runTokenMigrationCompatibilityTest(t, db, "postgres", managedTokensTable)
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetAllTokensMasksKeyInResponse(t *testing.T) {
|
func TestGetAllTokensMasksKeyInResponse(t *testing.T) {
|
||||||
db := setupTokenControllerTestDB(t)
|
db := setupTokenControllerTestDB(t)
|
||||||
token := seedToken(t, db, 1, "list-token", "abcd1234efgh5678")
|
token := seedToken(t, db, 1, "list-token", "abcd1234efgh5678")
|
||||||
|
|||||||
+1
-1
@@ -14,7 +14,7 @@ import (
|
|||||||
type Token struct {
|
type Token struct {
|
||||||
Id int `json:"id"`
|
Id int `json:"id"`
|
||||||
UserId int `json:"user_id" gorm:"index"`
|
UserId int `json:"user_id" gorm:"index"`
|
||||||
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
|
Key string `json:"key" gorm:"type:varchar(128);uniqueIndex"`
|
||||||
Status int `json:"status" gorm:"default:1"`
|
Status int `json:"status" gorm:"default:1"`
|
||||||
Name string `json:"name" gorm:"index" `
|
Name string `json:"name" gorm:"index" `
|
||||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||||
|
|||||||
Reference in New Issue
Block a user