package database import ( "database/sql" "fmt" "strings" "sync" "sunhpc/pkg/config" "sunhpc/pkg/logger" _ "github.com/go-sql-driver/mysql" _ "github.com/mattn/go-sqlite3" ) // ========================================================= // 全局变量 // ========================================================= var ( dbInstance *sql.DB dbOnce sync.Once dbMutex sync.RWMutex dbErr error ) // 封装数据库函数使用Go实现 // MapCategory - 根据类别名称查ID // 查询方式: globalID, err := db.MapCategory(conn, "global") func MapCategory(conn *sql.DB, catname string) (int, error) { var id int query := "select id from categories where name = ?" fullSQL := ReplaceSQLQuery(query, catname) err := conn.QueryRow(query, catname).Scan(&id) if err == sql.ErrNoRows { logger.Debugf("未找到类别 %s, 返回ID=0", catname) return 0, nil // 无匹配返回0 } logger.Debugf("查询语句: %s , CatName=%s, ID=%d", fullSQL, catname, id) return id, nil } // MapCategoryIndex - 根据类别名称 + 索引名称查ID // 调用方式: linuxOSID, err := db.MapCategoryIndex(conn, "os", "linux") func MapCategoryIndex(conn *sql.DB, catindexName, categoryIndex string) (int, error) { var id int query := ` select index_id from vmapCategoryIndex where categoryName = ? and categoryIndex = ? ` fullSQL := ReplaceSQLQuery(query, catindexName, categoryIndex) err := conn.QueryRow(query, catindexName, categoryIndex).Scan(&id) if err == sql.ErrNoRows { logger.Debugf("未找到索引 %s, 返回ID=0", catindexName) return 0, nil // 无匹配返回0 } logger.Debugf("查询语句: %s , CatIndexName=%s, CategoryIndex=%s, ID=%d", fullSQL, catindexName, categoryIndex, id) return id, nil } // ResolveFirewalls - 解析指定主机的防火墙规则 // 返回解析后的防火墙规则(fwresolved表数据),临时表使用后自动清理 // 调用方式: rows, err := db.ResolveFirewalls(conn, "compute-0-1", "default") func ResolveFirewalls(conn *sql.DB, hostname, chainname string) (*sql.Rows, error) { // 步骤1: 创建临时表 fresolved1 _, err := conn.Exec(` DROP TABLE IF EXISTS fresolved1; CREATE TEMPORARY TABLE fresolved1 AS SELECT ? AS hostname, ? AS Resolver, f.*, r.precedence FROM resolvechain r inner join hostselections hs on r.category = hs.category and r.name = ? inner join firewalls f on hs.category = f.category and hs.selection = f.catindex where hs.host = ?; `, hostname, chainname, chainname, hostname) if err != nil { return nil, fmt.Errorf("Create temporary table fresolved1 failed: %w", err) } // 步骤2:创建临时表 fresolved2 _, err = conn.Exec(` DROP TABLE IF EXISTS fresolved2; CREATE TEMPORARY TABLE fresolved2 AS SELECT * FROM fresolved1; `) if err != nil { return nil, fmt.Errorf("Create temporary table fresolved2 failed: %w", err) } // 步骤3:创建最终结果表 fwresolved _, err = conn.Exec(` DROP TABLE IF EXISTS fwresolved; CREATE TEMPORARY TABLE fwresolved AS SELECT r1.*, cat.name AS categoryName FROM fresolved1 r1 inner join ( select Rulename, MAX(precedence) as precedence from fresolved2 group by Rulename ) AS r2 on r1.Rulename = r2.Rulename and r1.precedence = r2.precedence inner join categories cat on r1.category = cat.id; `) if err != nil { return nil, fmt.Errorf("Create temporary table fwresolved failed: %w", err) } // 步骤4:查询结果并返回 rows, err := conn.Query("SELECT * FROM fwresolved") if err != nil { return nil, fmt.Errorf("Query fwresolved failed: %w", err) } return rows, nil } // ========================================================= // GetDB - 获取数据库连接(单例模式) // ========================================================= func GetDB() (*sql.DB, error) { logger.Debug("获取数据库连接...") dbOnce.Do(func() { if dbInstance != nil { return } // 确保配置已加载 cfg, err := config.LoadConfig() if err != nil { dbErr = fmt.Errorf("加载配置失败: %w", err) return } // 构建DSN logger.Debugf("DSN: %s", cfg.Database.DSN) // 打开SQLite 连接 sqlDB, err := sql.Open("sqlite3", cfg.Database.DSN) if err != nil { dbErr = fmt.Errorf("数据库打开失败: %w", err) return } // 设置连接池参数 sqlDB.SetMaxOpenConns(10) // 最大打开连接数 sqlDB.SetMaxIdleConns(5) // 保持空闲连接 sqlDB.SetConnMaxLifetime(0) // 禁用连接生命周期超时 sqlDB.SetConnMaxIdleTime(0) // 禁用空闲连接超时 // 测试数据库连接 if err := sqlDB.Ping(); err != nil { sqlDB.Close() dbErr = fmt.Errorf("数据库连接失败: %w", err) return } var version string err = sqlDB.QueryRow("select sqlite_version()").Scan(&version) if err != nil { version = "unknown" } logger.Debugf("数据库版本: %s", version) logger.Debug("数据库连接成功") dbInstance = sqlDB }) if dbErr != nil { return nil, dbErr } return dbInstance, nil } func InitTables(db *sql.DB, force bool) error { // 临时关闭外键约束(解决外键依赖删除报错问题) _, err := db.Exec("PRAGMA foreign_keys = OFF;") if err != nil { logger.Errorf("关闭外键约束失败: %v", err) return err } defer func() { // 延迟恢复外键约束(确保在函数退出时恢复) _, err := db.Exec("PRAGMA foreign_keys = ON;") if err != nil { logger.Errorf("恢复外键约束失败: %v", err) } }() // ✅ 调用 schema.go 中的函数 for name, ddl := range BaseTables() { // 删除表或者试图(如果存在) logger.Debugf("执行删除 - %s", name) // 先尝试作为表进行删除 query := fmt.Sprintf("DROP TABLE IF EXISTS %s;", name) logger.Debugf("执行语句: %s", query) _, err := db.Exec(query) if err != nil { // 如果作为表删除失败,尝试作为试图删除 logger.Debugf("删除表失败: %v", err) query = fmt.Sprintf("DROP VIEW IF EXISTS %s;", name) logger.Debugf("执行语句: %s", query) _, err = db.Exec(query) if err != nil { return fmt.Errorf("删除失败: %w", err) } } logger.Debugf("执行图表 - %s", name) logger.Debugf("执行语句: %s", ddl) if _, err := db.Exec(ddl); err != nil { return fmt.Errorf("数据表创建失败: %w", err) } } logger.Info("数据库表创建成功") /* 使用sqlite3命令 测试数据库是否存在表 ✅ 查询所有表 sqlite3 /var/lib/sunhpc/sunhpc.db .tables # 查看所有表 select * from sqlite_master where type='table'; # 查看表定义 PRAGMA integrity_check; # 检查数据库完整性 */ // 添加基础数据 if err := InitBaseData(db); err != nil { return fmt.Errorf("初始化基础数据失败: %w", err) } logger.Info("基础数据初始化成功") return nil } func CloseDB() error { dbMutex.Lock() defer dbMutex.Unlock() if dbInstance == nil { if err := dbInstance.Close(); err != nil { return err } dbInstance = nil } return nil } // 使用事务回滚测试 func RunTestWithRollback(db *sql.DB, testFunc func(*sql.Tx) error) error { tx, err := db.Begin() if err != nil { return err } // 执行测试 if err := testFunc(tx); err != nil { tx.Rollback() return err } // 回滚事务,所有更改(包括 ID 递增)都会撤销 return tx.Rollback() } // 使用示例 func TestNodeInsert(db *sql.DB) error { logger.Debug("测试数据插入...") return RunTestWithRollback(db, func(tx *sql.Tx) error { // 插入测试数据 logger.Debug("执行插入测试数据...") _, err := tx.Exec(` INSERT INTO nodes (name, cpus, rack, rank) VALUES (?, ?, ?, ?) `, "test-node", 64, 1, 1) if err != nil { return err } // 验证插入 var count int logger.Debug("执行查询测试数据...") err = tx.QueryRow(` SELECT COUNT(*) FROM nodes WHERE name = ? `, "test-node").Scan(&count) if err != nil { return err } logger.Infof("测试数据插入成功,共 %d 条", count) // 不需要手动删除,回滚会自动撤销 return nil }) } // ========================================================= // 带事务执行 SQL 语句,自动提交/回滚 // ========================================================= // 执行单条SQL语句,带事务管理 func ExecSingleWithTransaction(sqlStr string) error { // 复用批量函数,将单条SQL语句包装为数组执行 return ExecWithTransaction([]string{sqlStr}) } // 批量执行 DDL 语句,带事务管理 func ExecWithTransaction(ddl []string) error { conn, err := GetDB() if err != nil { logger.Errorf("获取数据库连接失败: %v", err) return err } // 开始事务 tx, err := conn.Begin() if err != nil { logger.Errorf("开始事务失败: %v", err) return err } var finished bool // 延迟处理:如果函数异常,回滚事务 defer func() { if r := recover(); r != nil { if !finished { // 捕获 panic 并回滚事务 tx.Rollback() logger.Errorf("事务执行中发生 panic: %v", r) } panic(r) } }() // 遍历执行 DDL 语句 for idx, sql := range ddl { logger.Debugf("执行 DDL 语句 %d: %s", idx+1, sql) _, err = tx.Exec(sql) if err != nil { // 执行失败时,回滚事务 rollbackErr := tx.Rollback() finished = true // 标记事务已完成 if rollbackErr != nil { logger.Errorf("执行失败: 回滚失败: %v (原错误: %v, SQL: %s)", rollbackErr, err, sql) } else { logger.Errorf("执行失败: 回滚事务: %v, SQL: %s", err, sql) } logger.Errorf("执行 %d 条, 失败: %w (SQL: %s)", idx+1, err, sql) return fmt.Errorf("执行 %d 条, 失败: %w (SQL: %s)", idx+1, err, sql) } } // 所有SQL语句执行成功,提交事务 logger.Info("所有SQL语句执行成功,提交事务") if err := tx.Commit(); err != nil { logger.Errorf("提交事务失败: %w", err) return err } finished = true // 标记事务已完成 logger.Debugf("成功执行 %d 条 SQL 语句, 事务已提交.", len(ddl)) return nil } func ReplaceSQLQuery(query string, args ...interface{}) string { for _, arg := range args { switch v := arg.(type) { case string: query = strings.Replace(query, "?", fmt.Sprintf("'%s'", v), 1) case int, int64, float64: query = strings.Replace(query, "?", fmt.Sprintf("%v", v), 1) default: query = strings.Replace(query, "?", fmt.Sprintf("%v", v), 1) } } return strings.TrimSpace(strings.ReplaceAll(query, "\n", " ")) }