package database import ( "bufio" "database/sql" "fmt" "os" "strings" "sync" "sunhpc/pkg/config" "sunhpc/pkg/logger" _ "github.com/go-sql-driver/mysql" _ "github.com/mattn/go-sqlite3" ) // ========================================================= // 全局变量 // ========================================================= var ( dbInstance *sql.DB dbOnce sync.Once dbMutex sync.RWMutex dbErr error ) // ========================================================= // GetDB - 获取数据库连接(单例模式) // ========================================================= func GetDB() (*sql.DB, error) { dbOnce.Do(func() { if dbInstance != nil { return } // 确保配置已加载 cfg, err := config.LoadConfig() if err != nil { dbErr = fmt.Errorf("加载配置失败: %w", err) return } // 构建DSN logger.Debugf("DSN: %s", cfg.Database.DSN) // 打开SQLite 连接 sqlDB, err := sql.Open("sqlite3", cfg.Database.DSN) if err != nil { dbErr = fmt.Errorf("数据库打开失败: %w", err) return } // 设置连接池参数 sqlDB.SetMaxOpenConns(10) // 最大打开连接数 sqlDB.SetMaxIdleConns(5) // 保持空闲连接 sqlDB.SetConnMaxLifetime(0) // 禁用连接生命周期超时 sqlDB.SetConnMaxIdleTime(0) // 禁用空闲连接超时 // 测试数据库连接 if err := sqlDB.Ping(); err != nil { sqlDB.Close() dbErr = fmt.Errorf("数据库连接失败: %w", err) return } logger.Debug("数据库连接成功") dbInstance = sqlDB }) if dbErr != nil { return nil, dbErr } return dbInstance, nil } func confirmAction(prompt string) bool { reader := bufio.NewReader(os.Stdin) logger.Warnf("%s [Y/Yes]: ", prompt) response, err := reader.ReadString('\n') if err != nil { return false } response = strings.ToLower(strings.TrimSpace(response)) return response == "y" || response == "yes" } func InitTables(db *sql.DB, force bool) error { if force { // 确认是否强制删除 if !confirmAction("确认强制删除所有表和触发器?") { logger.Info("操作已取消") db.Close() os.Exit(0) return nil } // 强制删除所有表和触发器 logger.Debug("强制删除所有表和触发器...") if err := dropTables(db); err != nil { return fmt.Errorf("删除表失败: %w", err) } logger.Debug("删除所有表和触发器成功") if err := dropTriggers(db); err != nil { return fmt.Errorf("删除触发器失败: %w", err) } logger.Debug("删除所有触发器成功") } // ✅ 调用 schema.go 中的函数 for _, ddl := range CreateTableStatements() { logger.Debugf("执行: %s", ddl) if _, err := db.Exec(ddl); err != nil { return fmt.Errorf("数据表创建失败: %w", err) } } logger.Info("数据库表创建成功") /* 使用sqlite3命令 测试数据库是否存在表 ✅ 查询所有表 sqlite3 /var/lib/sunhpc/sunhpc.db .tables # 查看所有表 select * from sqlite_master where type='table'; # 查看表定义 PRAGMA integrity_check; # 检查数据库完整性 */ return nil } func dropTables(db *sql.DB) error { // ✅ 调用 schema.go 中的函数 for _, table := range DropTableOrder() { if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS `%s`", table)); err != nil { return err } } return nil } func dropTriggers(db *sql.DB) error { // ✅ 调用 schema.go 中的函数 for _, trigger := range DropTriggerStatements() { if _, err := db.Exec(fmt.Sprintf("DROP TRIGGER IF EXISTS `%s`", trigger)); err != nil { return err } } return nil } func CloseDB() error { dbMutex.Lock() defer dbMutex.Unlock() if dbInstance == nil { if err := dbInstance.Close(); err != nil { return err } dbInstance = nil } return nil } // 使用事务回滚测试 func RunTestWithRollback(db *sql.DB, testFunc func(*sql.Tx) error) error { tx, err := db.Begin() if err != nil { return err } // 执行测试 if err := testFunc(tx); err != nil { tx.Rollback() return err } // 回滚事务,所有更改(包括 ID 递增)都会撤销 return tx.Rollback() } // 使用示例 func TestNodeInsert(db *sql.DB) error { logger.Debug("测试数据插入...") return RunTestWithRollback(db, func(tx *sql.Tx) error { // 插入测试数据 logger.Debug("执行插入测试数据...") _, err := tx.Exec(` INSERT INTO nodes (name, cpus, rack, rank) VALUES (?, ?, ?, ?) `, "test-node", 64, 1, 1) if err != nil { return err } // 验证插入 var count int logger.Debug("执行查询测试数据...") err = tx.QueryRow(` SELECT COUNT(*) FROM nodes WHERE name = ? `, "test-node").Scan(&count) if err != nil { return err } logger.Infof("测试数据插入成功,共 %d 条", count) // 不需要手动删除,回滚会自动撤销 return nil }) }