diff --git a/controller/token_test.go b/controller/token_test.go index 7ffb38b4..0c0f504b 100644 --- a/controller/token_test.go +++ b/controller/token_test.go @@ -114,7 +114,7 @@ func setupTokenControllerTestDB(t *testing.T) *gorm.DB { return db } -func openTokenControllerExternalDB(t *testing.T, dialect string, dsn string) *gorm.DB { +func openTokenControllerExternalDB(t *testing.T, dialect string, dsn string) (*gorm.DB, *bool) { t.Helper() gin.SetMode(gin.TestMode) @@ -142,15 +142,23 @@ func openTokenControllerExternalDB(t *testing.T, dialect string, dsn string) *go model.DB = db model.LOG_DB = db + if db.Migrator().HasTable("tokens") { + t.Skipf("refusing to run %s migration compatibility test against external database because tokens table already exists", dialect) + } + + managedTokensTable := new(bool) + t.Cleanup(func() { - _ = db.Exec("DROP TABLE IF EXISTS tokens").Error + if *managedTokensTable && db.Migrator().HasTable("tokens") { + _ = db.Migrator().DropTable("tokens") + } sqlDB, err := db.DB() if err == nil { _ = sqlDB.Close() } }) - return db + return db, managedTokensTable } func seedToken(t *testing.T, db *gorm.DB, userID int, name string, rawKey string) *model.Token { @@ -266,18 +274,18 @@ func getTokenKeyColumnType(t *testing.T, db *gorm.DB, dialect string) string { } } -func runTokenMigrationCompatibilityTest(t *testing.T, db *gorm.DB, dialect string) { +func runTokenMigrationCompatibilityTest(t *testing.T, db *gorm.DB, dialect string, managedTokensTable *bool) { t.Helper() legacyKey := strings.Repeat("a", 48) longKey := strings.Repeat("b", 64) - if err := db.Exec("DROP TABLE IF EXISTS tokens").Error; err != nil { - t.Fatalf("failed to drop pre-existing token table: %v", err) - } if err := db.AutoMigrate(&legacyToken{}); err != nil { t.Fatalf("failed to create legacy token schema: %v", err) } + if managedTokensTable != nil { + *managedTokensTable = true + } if err := db.Create(&legacyToken{ UserId: 7, Key: legacyKey, @@ -359,7 +367,7 @@ func TestTokenAutoMigrateUsesVarchar128KeyColumn(t *testing.T) { func TestTokenMigrationFromChar48ToVarchar128(t *testing.T) { db := openTokenControllerTestDB(t) - runTokenMigrationCompatibilityTest(t, db, "sqlite") + runTokenMigrationCompatibilityTest(t, db, "sqlite", nil) } func TestTokenMigrationFromChar48ToVarchar128MySQL(t *testing.T) { @@ -368,8 +376,8 @@ func TestTokenMigrationFromChar48ToVarchar128MySQL(t *testing.T) { t.Skip("set TEST_MYSQL_DSN to run mysql migration compatibility test") } - db := openTokenControllerExternalDB(t, "mysql", dsn) - runTokenMigrationCompatibilityTest(t, db, "mysql") + db, managedTokensTable := openTokenControllerExternalDB(t, "mysql", dsn) + runTokenMigrationCompatibilityTest(t, db, "mysql", managedTokensTable) } func TestTokenMigrationFromChar48ToVarchar128Postgres(t *testing.T) { @@ -378,8 +386,8 @@ func TestTokenMigrationFromChar48ToVarchar128Postgres(t *testing.T) { t.Skip("set TEST_POSTGRES_DSN to run postgres migration compatibility test") } - db := openTokenControllerExternalDB(t, "postgres", dsn) - runTokenMigrationCompatibilityTest(t, db, "postgres") + db, managedTokensTable := openTokenControllerExternalDB(t, "postgres", dsn) + runTokenMigrationCompatibilityTest(t, db, "postgres", managedTokensTable) } func TestGetAllTokensMasksKeyInResponse(t *testing.T) {