393 lines
12 KiB
Go
393 lines
12 KiB
Go
package controller
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strconv"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/QuantumNous/new-api/common"
|
|
"github.com/QuantumNous/new-api/model"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/glebarez/sqlite"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// tokenAPIResponse is the JSON envelope the tests decode from a recorded
// handler response. Data is kept as raw JSON so each test can decode it into
// the payload shape it expects (a page, a single item, or a key).
type tokenAPIResponse struct {
	Success bool            `json:"success"`
	Message string          `json:"message"`
	Data    json.RawMessage `json:"data"`
}
|
|
|
|
type tokenPageResponse struct {
|
|
Items []tokenResponseItem `json:"items"`
|
|
}
|
|
|
|
// tokenResponseItem holds the subset of token fields the tests read back from
// a response: the identifier, display name, key as rendered by the API, and
// status.
type tokenResponseItem struct {
	ID     int    `json:"id"`
	Name   string `json:"name"`
	Key    string `json:"key"`
	Status int    `json:"status"`
}
|
|
|
|
// tokenKeyResponse is the payload decoded from the "data" field of the
// token-key endpoint response; it carries only the key string.
type tokenKeyResponse struct {
	Key string `json:"key"`
}
|
|
|
|
// sqliteColumnInfo receives the "name" and "type" columns of a
// `PRAGMA table_info(...)` result row when scanned through gorm.
type sqliteColumnInfo struct {
	Name string `gorm:"column:name"`
	Type string `gorm:"column:type"`
}
|
|
|
|
type legacyToken struct {
|
|
Id int `gorm:"primaryKey"`
|
|
UserId int `gorm:"index"`
|
|
Key string `gorm:"column:key;type:char(48);uniqueIndex"`
|
|
Status int `gorm:"default:1"`
|
|
Name string `gorm:"index"`
|
|
CreatedTime int64 `gorm:"bigint"`
|
|
AccessedTime int64 `gorm:"bigint"`
|
|
ExpiredTime int64 `gorm:"bigint;default:-1"`
|
|
RemainQuota int `gorm:"default:0"`
|
|
UnlimitedQuota bool
|
|
ModelLimitsEnabled bool
|
|
ModelLimits string `gorm:"type:text"`
|
|
AllowIps *string `gorm:"default:''"`
|
|
UsedQuota int `gorm:"default:0"`
|
|
Group string `gorm:"column:group;default:''"`
|
|
CrossGroupRetry bool
|
|
DeletedAt gorm.DeletedAt `gorm:"index"`
|
|
}
|
|
|
|
func (legacyToken) TableName() string {
|
|
return "tokens"
|
|
}
|
|
|
|
func openTokenControllerTestDB(t *testing.T) *gorm.DB {
|
|
t.Helper()
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
common.UsingSQLite = true
|
|
common.UsingMySQL = false
|
|
common.UsingPostgreSQL = false
|
|
common.RedisEnabled = false
|
|
|
|
dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_"))
|
|
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
|
if err != nil {
|
|
t.Fatalf("failed to open sqlite db: %v", err)
|
|
}
|
|
model.DB = db
|
|
model.LOG_DB = db
|
|
|
|
t.Cleanup(func() {
|
|
sqlDB, err := db.DB()
|
|
if err == nil {
|
|
_ = sqlDB.Close()
|
|
}
|
|
})
|
|
|
|
return db
|
|
}
|
|
|
|
func migrateTokenControllerTestDB(t *testing.T, db *gorm.DB) {
|
|
t.Helper()
|
|
|
|
if err := db.AutoMigrate(&model.Token{}); err != nil {
|
|
t.Fatalf("failed to migrate token table: %v", err)
|
|
}
|
|
}
|
|
|
|
func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
|
|
t.Helper()
|
|
|
|
db := openTokenControllerTestDB(t)
|
|
migrateTokenControllerTestDB(t, db)
|
|
return db
|
|
}
|
|
|
|
func seedToken(t *testing.T, db *gorm.DB, userID int, name string, rawKey string) *model.Token {
|
|
t.Helper()
|
|
|
|
token := &model.Token{
|
|
UserId: userID,
|
|
Name: name,
|
|
Key: rawKey,
|
|
Status: common.TokenStatusEnabled,
|
|
CreatedTime: 1,
|
|
AccessedTime: 1,
|
|
ExpiredTime: -1,
|
|
RemainQuota: 100,
|
|
UnlimitedQuota: true,
|
|
Group: "default",
|
|
}
|
|
if err := db.Create(token).Error; err != nil {
|
|
t.Fatalf("failed to create token: %v", err)
|
|
}
|
|
return token
|
|
}
|
|
|
|
func newAuthenticatedContext(t *testing.T, method string, target string, body any, userID int) (*gin.Context, *httptest.ResponseRecorder) {
|
|
t.Helper()
|
|
|
|
var requestBody *bytes.Reader
|
|
if body != nil {
|
|
payload, err := common.Marshal(body)
|
|
if err != nil {
|
|
t.Fatalf("failed to marshal request body: %v", err)
|
|
}
|
|
requestBody = bytes.NewReader(payload)
|
|
} else {
|
|
requestBody = bytes.NewReader(nil)
|
|
}
|
|
|
|
recorder := httptest.NewRecorder()
|
|
ctx, _ := gin.CreateTestContext(recorder)
|
|
ctx.Request = httptest.NewRequest(method, target, requestBody)
|
|
if body != nil {
|
|
ctx.Request.Header.Set("Content-Type", "application/json")
|
|
}
|
|
ctx.Set("id", userID)
|
|
return ctx, recorder
|
|
}
|
|
|
|
func decodeAPIResponse(t *testing.T, recorder *httptest.ResponseRecorder) tokenAPIResponse {
|
|
t.Helper()
|
|
|
|
var response tokenAPIResponse
|
|
if err := common.Unmarshal(recorder.Body.Bytes(), &response); err != nil {
|
|
t.Fatalf("failed to decode api response: %v", err)
|
|
}
|
|
return response
|
|
}
|
|
|
|
func getSQLiteColumnType(t *testing.T, db *gorm.DB, tableName string, columnName string) string {
|
|
t.Helper()
|
|
|
|
var columns []sqliteColumnInfo
|
|
if err := db.Raw("PRAGMA table_info(" + tableName + ")").Scan(&columns).Error; err != nil {
|
|
t.Fatalf("failed to inspect %s schema: %v", tableName, err)
|
|
}
|
|
|
|
for _, column := range columns {
|
|
if column.Name == columnName {
|
|
return strings.ToLower(column.Type)
|
|
}
|
|
}
|
|
|
|
t.Fatalf("column %s not found in %s schema", columnName, tableName)
|
|
return ""
|
|
}
|
|
|
|
func TestTokenAutoMigrateUsesVarchar128KeyColumn(t *testing.T) {
|
|
db := setupTokenControllerTestDB(t)
|
|
|
|
if got := getSQLiteColumnType(t, db, "tokens", "key"); got != "varchar(128)" {
|
|
t.Fatalf("expected key column type varchar(128), got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestTokenMigrationFromChar48ToVarchar128(t *testing.T) {
|
|
db := openTokenControllerTestDB(t)
|
|
legacyKey := strings.Repeat("a", 48)
|
|
|
|
if err := db.AutoMigrate(&legacyToken{}); err != nil {
|
|
t.Fatalf("failed to create legacy token schema: %v", err)
|
|
}
|
|
if err := db.Create(&legacyToken{
|
|
Id: 1,
|
|
UserId: 7,
|
|
Key: legacyKey,
|
|
Status: common.TokenStatusEnabled,
|
|
Name: "legacy-token",
|
|
CreatedTime: 1,
|
|
AccessedTime: 1,
|
|
ExpiredTime: -1,
|
|
RemainQuota: 100,
|
|
UnlimitedQuota: true,
|
|
ModelLimitsEnabled: false,
|
|
ModelLimits: "",
|
|
AllowIps: common.GetPointer(""),
|
|
UsedQuota: 0,
|
|
Group: "default",
|
|
CrossGroupRetry: false,
|
|
}).Error; err != nil {
|
|
t.Fatalf("failed to seed legacy token row: %v", err)
|
|
}
|
|
|
|
if got := getSQLiteColumnType(t, db, "tokens", "key"); got != "char(48)" {
|
|
t.Fatalf("expected legacy key column type char(48), got %q", got)
|
|
}
|
|
|
|
migrateTokenControllerTestDB(t, db)
|
|
|
|
if got := getSQLiteColumnType(t, db, "tokens", "key"); got != "varchar(128)" {
|
|
t.Fatalf("expected migrated key column type varchar(128), got %q", got)
|
|
}
|
|
|
|
var migratedToken model.Token
|
|
if err := db.First(&migratedToken, "id = ?", 1).Error; err != nil {
|
|
t.Fatalf("failed to load migrated token row: %v", err)
|
|
}
|
|
if migratedToken.Key != legacyKey {
|
|
t.Fatalf("expected migrated token key %q, got %q", legacyKey, migratedToken.Key)
|
|
}
|
|
if migratedToken.Name != "legacy-token" {
|
|
t.Fatalf("expected migrated token name to be preserved, got %q", migratedToken.Name)
|
|
}
|
|
}
|
|
|
|
func TestGetAllTokensMasksKeyInResponse(t *testing.T) {
|
|
db := setupTokenControllerTestDB(t)
|
|
token := seedToken(t, db, 1, "list-token", "abcd1234efgh5678")
|
|
seedToken(t, db, 2, "other-user-token", "zzzz1234yyyy5678")
|
|
|
|
ctx, recorder := newAuthenticatedContext(t, http.MethodGet, "/api/token/?p=1&size=10", nil, 1)
|
|
GetAllTokens(ctx)
|
|
|
|
response := decodeAPIResponse(t, recorder)
|
|
if !response.Success {
|
|
t.Fatalf("expected success response, got message: %s", response.Message)
|
|
}
|
|
|
|
var page tokenPageResponse
|
|
if err := common.Unmarshal(response.Data, &page); err != nil {
|
|
t.Fatalf("failed to decode token page response: %v", err)
|
|
}
|
|
if len(page.Items) != 1 {
|
|
t.Fatalf("expected exactly one token, got %d", len(page.Items))
|
|
}
|
|
if page.Items[0].Key != token.GetMaskedKey() {
|
|
t.Fatalf("expected masked key %q, got %q", token.GetMaskedKey(), page.Items[0].Key)
|
|
}
|
|
if strings.Contains(recorder.Body.String(), token.Key) {
|
|
t.Fatalf("list response leaked raw token key: %s", recorder.Body.String())
|
|
}
|
|
}
|
|
|
|
func TestSearchTokensMasksKeyInResponse(t *testing.T) {
|
|
db := setupTokenControllerTestDB(t)
|
|
token := seedToken(t, db, 1, "searchable-token", "ijkl1234mnop5678")
|
|
|
|
ctx, recorder := newAuthenticatedContext(t, http.MethodGet, "/api/token/search?keyword=searchable-token&p=1&size=10", nil, 1)
|
|
SearchTokens(ctx)
|
|
|
|
response := decodeAPIResponse(t, recorder)
|
|
if !response.Success {
|
|
t.Fatalf("expected success response, got message: %s", response.Message)
|
|
}
|
|
|
|
var page tokenPageResponse
|
|
if err := common.Unmarshal(response.Data, &page); err != nil {
|
|
t.Fatalf("failed to decode search response: %v", err)
|
|
}
|
|
if len(page.Items) != 1 {
|
|
t.Fatalf("expected exactly one search result, got %d", len(page.Items))
|
|
}
|
|
if page.Items[0].Key != token.GetMaskedKey() {
|
|
t.Fatalf("expected masked search key %q, got %q", token.GetMaskedKey(), page.Items[0].Key)
|
|
}
|
|
if strings.Contains(recorder.Body.String(), token.Key) {
|
|
t.Fatalf("search response leaked raw token key: %s", recorder.Body.String())
|
|
}
|
|
}
|
|
|
|
func TestGetTokenMasksKeyInResponse(t *testing.T) {
|
|
db := setupTokenControllerTestDB(t)
|
|
token := seedToken(t, db, 1, "detail-token", "qrst1234uvwx5678")
|
|
|
|
ctx, recorder := newAuthenticatedContext(t, http.MethodGet, "/api/token/"+strconv.Itoa(token.Id), nil, 1)
|
|
ctx.Params = gin.Params{{Key: "id", Value: strconv.Itoa(token.Id)}}
|
|
GetToken(ctx)
|
|
|
|
response := decodeAPIResponse(t, recorder)
|
|
if !response.Success {
|
|
t.Fatalf("expected success response, got message: %s", response.Message)
|
|
}
|
|
|
|
var detail tokenResponseItem
|
|
if err := common.Unmarshal(response.Data, &detail); err != nil {
|
|
t.Fatalf("failed to decode token detail response: %v", err)
|
|
}
|
|
if detail.Key != token.GetMaskedKey() {
|
|
t.Fatalf("expected masked detail key %q, got %q", token.GetMaskedKey(), detail.Key)
|
|
}
|
|
if strings.Contains(recorder.Body.String(), token.Key) {
|
|
t.Fatalf("detail response leaked raw token key: %s", recorder.Body.String())
|
|
}
|
|
}
|
|
|
|
func TestUpdateTokenMasksKeyInResponse(t *testing.T) {
|
|
db := setupTokenControllerTestDB(t)
|
|
token := seedToken(t, db, 1, "editable-token", "yzab1234cdef5678")
|
|
|
|
body := map[string]any{
|
|
"id": token.Id,
|
|
"name": "updated-token",
|
|
"expired_time": -1,
|
|
"remain_quota": 100,
|
|
"unlimited_quota": true,
|
|
"model_limits_enabled": false,
|
|
"model_limits": "",
|
|
"group": "default",
|
|
"cross_group_retry": false,
|
|
}
|
|
|
|
ctx, recorder := newAuthenticatedContext(t, http.MethodPut, "/api/token/", body, 1)
|
|
UpdateToken(ctx)
|
|
|
|
response := decodeAPIResponse(t, recorder)
|
|
if !response.Success {
|
|
t.Fatalf("expected success response, got message: %s", response.Message)
|
|
}
|
|
|
|
var detail tokenResponseItem
|
|
if err := common.Unmarshal(response.Data, &detail); err != nil {
|
|
t.Fatalf("failed to decode token update response: %v", err)
|
|
}
|
|
if detail.Key != token.GetMaskedKey() {
|
|
t.Fatalf("expected masked update key %q, got %q", token.GetMaskedKey(), detail.Key)
|
|
}
|
|
if strings.Contains(recorder.Body.String(), token.Key) {
|
|
t.Fatalf("update response leaked raw token key: %s", recorder.Body.String())
|
|
}
|
|
}
|
|
|
|
func TestGetTokenKeyRequiresOwnershipAndReturnsFullKey(t *testing.T) {
|
|
db := setupTokenControllerTestDB(t)
|
|
token := seedToken(t, db, 1, "owned-token", "owner1234token5678")
|
|
|
|
authorizedCtx, authorizedRecorder := newAuthenticatedContext(t, http.MethodPost, "/api/token/"+strconv.Itoa(token.Id)+"/key", nil, 1)
|
|
authorizedCtx.Params = gin.Params{{Key: "id", Value: strconv.Itoa(token.Id)}}
|
|
GetTokenKey(authorizedCtx)
|
|
|
|
authorizedResponse := decodeAPIResponse(t, authorizedRecorder)
|
|
if !authorizedResponse.Success {
|
|
t.Fatalf("expected authorized key fetch to succeed, got message: %s", authorizedResponse.Message)
|
|
}
|
|
|
|
var keyData tokenKeyResponse
|
|
if err := common.Unmarshal(authorizedResponse.Data, &keyData); err != nil {
|
|
t.Fatalf("failed to decode token key response: %v", err)
|
|
}
|
|
if keyData.Key != token.GetFullKey() {
|
|
t.Fatalf("expected full key %q, got %q", token.GetFullKey(), keyData.Key)
|
|
}
|
|
|
|
unauthorizedCtx, unauthorizedRecorder := newAuthenticatedContext(t, http.MethodPost, "/api/token/"+strconv.Itoa(token.Id)+"/key", nil, 2)
|
|
unauthorizedCtx.Params = gin.Params{{Key: "id", Value: strconv.Itoa(token.Id)}}
|
|
GetTokenKey(unauthorizedCtx)
|
|
|
|
unauthorizedResponse := decodeAPIResponse(t, unauthorizedRecorder)
|
|
if unauthorizedResponse.Success {
|
|
t.Fatalf("expected unauthorized key fetch to fail")
|
|
}
|
|
if strings.Contains(unauthorizedRecorder.Body.String(), token.Key) {
|
|
t.Fatalf("unauthorized key response leaked raw token key: %s", unauthorizedRecorder.Body.String())
|
|
}
|
|
}
|