215 lines
4.7 KiB
Go
215 lines
4.7 KiB
Go
package database
|
|
|
|
import (
|
|
"bufio"
|
|
"database/sql"
|
|
"fmt"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
|
|
"sunhpc/pkg/config"
|
|
"sunhpc/pkg/logger"
|
|
|
|
_ "github.com/go-sql-driver/mysql"
|
|
_ "github.com/mattn/go-sqlite3"
|
|
)
|
|
|
|
// =========================================================
|
|
// 全局变量
|
|
// =========================================================
|
|
var (
|
|
dbInstance *sql.DB
|
|
dbOnce sync.Once
|
|
dbMutex sync.RWMutex
|
|
dbErr error
|
|
)
|
|
|
|
// =========================================================
|
|
// GetDB - 获取数据库连接(单例模式)
|
|
// =========================================================
|
|
func GetDB() (*sql.DB, error) {
|
|
dbOnce.Do(func() {
|
|
if dbInstance != nil {
|
|
return
|
|
}
|
|
|
|
// 确保配置已加载
|
|
cfg, err := config.LoadConfig()
|
|
if err != nil {
|
|
dbErr = fmt.Errorf("加载配置失败: %w", err)
|
|
return
|
|
}
|
|
|
|
// 构建DSN
|
|
logger.Debugf("DSN: %s", cfg.Database.DSN)
|
|
|
|
// 打开SQLite 连接
|
|
sqlDB, err := sql.Open("sqlite3", cfg.Database.DSN)
|
|
if err != nil {
|
|
dbErr = fmt.Errorf("数据库打开失败: %w", err)
|
|
return
|
|
}
|
|
|
|
// 设置连接池参数
|
|
sqlDB.SetMaxOpenConns(10) // 最大打开连接数
|
|
sqlDB.SetMaxIdleConns(5) // 保持空闲连接
|
|
sqlDB.SetConnMaxLifetime(0) // 禁用连接生命周期超时
|
|
sqlDB.SetConnMaxIdleTime(0) // 禁用空闲连接超时
|
|
|
|
// 测试数据库连接
|
|
if err := sqlDB.Ping(); err != nil {
|
|
sqlDB.Close()
|
|
dbErr = fmt.Errorf("数据库连接失败: %w", err)
|
|
return
|
|
}
|
|
|
|
logger.Debug("数据库连接成功")
|
|
dbInstance = sqlDB
|
|
})
|
|
|
|
if dbErr != nil {
|
|
return nil, dbErr
|
|
}
|
|
|
|
return dbInstance, nil
|
|
}
|
|
|
|
func confirmAction(prompt string) bool {
|
|
reader := bufio.NewReader(os.Stdin)
|
|
|
|
logger.Warnf("%s [Y/Yes]: ", prompt)
|
|
response, err := reader.ReadString('\n')
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
response = strings.ToLower(strings.TrimSpace(response))
|
|
return response == "y" || response == "yes"
|
|
}
|
|
|
|
func InitTables(db *sql.DB, force bool) error {
|
|
|
|
if force {
|
|
// 确认是否强制删除
|
|
if !confirmAction("确认强制删除所有表和触发器?") {
|
|
logger.Info("操作已取消")
|
|
db.Close()
|
|
os.Exit(0)
|
|
return nil
|
|
}
|
|
|
|
// 强制删除所有表和触发器
|
|
logger.Debug("强制删除所有表和触发器...")
|
|
if err := dropTables(db); err != nil {
|
|
return fmt.Errorf("删除表失败: %w", err)
|
|
}
|
|
logger.Debug("删除所有表和触发器成功")
|
|
|
|
if err := dropTriggers(db); err != nil {
|
|
return fmt.Errorf("删除触发器失败: %w", err)
|
|
}
|
|
logger.Debug("删除所有触发器成功")
|
|
}
|
|
|
|
// ✅ 调用 schema.go 中的函数
|
|
for _, ddl := range CreateTableStatements() {
|
|
logger.Debugf("执行: %s", ddl)
|
|
if _, err := db.Exec(ddl); err != nil {
|
|
return fmt.Errorf("数据表创建失败: %w", err)
|
|
}
|
|
}
|
|
logger.Info("数据库表创建成功")
|
|
/*
|
|
使用sqlite3命令 测试数据库是否存在表
|
|
✅ 查询所有表
|
|
sqlite3 /var/lib/sunhpc/sunhpc.db
|
|
.tables # 查看所有表
|
|
select * from sqlite_master where type='table'; # 查看表定义
|
|
PRAGMA integrity_check; # 检查数据库完整性
|
|
*/
|
|
return nil
|
|
}
|
|
|
|
func dropTables(db *sql.DB) error {
|
|
// ✅ 调用 schema.go 中的函数
|
|
for _, table := range DropTableOrder() {
|
|
if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS `%s`", table)); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func dropTriggers(db *sql.DB) error {
|
|
// ✅ 调用 schema.go 中的函数
|
|
for _, trigger := range DropTriggerStatements() {
|
|
if _, err := db.Exec(fmt.Sprintf("DROP TRIGGER IF EXISTS `%s`", trigger)); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func CloseDB() error {
|
|
dbMutex.Lock()
|
|
defer dbMutex.Unlock()
|
|
|
|
if dbInstance == nil {
|
|
if err := dbInstance.Close(); err != nil {
|
|
return err
|
|
}
|
|
dbInstance = nil
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// 使用事务回滚测试
|
|
func RunTestWithRollback(db *sql.DB, testFunc func(*sql.Tx) error) error {
|
|
tx, err := db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 执行测试
|
|
if err := testFunc(tx); err != nil {
|
|
tx.Rollback()
|
|
return err
|
|
}
|
|
|
|
// 回滚事务,所有更改(包括 ID 递增)都会撤销
|
|
return tx.Rollback()
|
|
}
|
|
|
|
// 使用示例
|
|
func TestNodeInsert(db *sql.DB) error {
|
|
logger.Debug("测试数据插入...")
|
|
return RunTestWithRollback(db, func(tx *sql.Tx) error {
|
|
// 插入测试数据
|
|
logger.Debug("执行插入测试数据...")
|
|
|
|
_, err := tx.Exec(`
|
|
INSERT INTO nodes (name, cpus, rack, rank)
|
|
VALUES (?, ?, ?, ?)
|
|
`, "test-node", 64, 1, 1)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 验证插入
|
|
var count int
|
|
logger.Debug("执行查询测试数据...")
|
|
err = tx.QueryRow(`
|
|
SELECT COUNT(*) FROM nodes WHERE name = ?
|
|
`, "test-node").Scan(&count)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
logger.Infof("测试数据插入成功,共 %d 条", count)
|
|
|
|
// 不需要手动删除,回滚会自动撤销
|
|
return nil
|
|
})
|
|
}
|