package db

import (
	"context"
	"database/sql"
	"fmt"
	"io/ioutil"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"sunhpc/internal/log"

	_ "github.com/mattn/go-sqlite3"
)

// Process-wide singleton state, initialised once by GetInstanceWithConfig.
var (
	globalDB  *DB
	globalErr error // result of the one-time initialisation, returned to every caller
	once      sync.Once
)

// DB is the core database handle, modelled on the Database class of Rocks.
//
// Connection parameters are guarded by mu. Connect opens a pooled *sql.DB
// (engine) and checks out one dedicated *sql.Conn (conn) that Execute runs
// statements on; the rows of the most recent SELECT are kept in results and
// consumed through FetchOne/FetchAll.
type DB struct {
	// Connection parameters.
	dbUser    string
	dbPasswd  string
	dbHost    string
	dbName    string
	dbPath    string
	dbSocket  string
	verbose   bool
	forceInit bool

	// Connection objects.
	engine  *sql.DB   // connection pool
	conn    *sql.Conn // dedicated connection used by Execute
	results *sql.Rows // rows of the last SELECT run through Execute

	// Emulated thread-local storage.
	sessions sync.Map
	mu       sync.RWMutex
}

// NewDB returns a DB with the default sunhpc settings (database "sunhpc"
// under /var/lib/sunhpc). It does not connect; call Connect for that.
func NewDB() *DB {
	return &DB{
		dbUser:   "",
		dbPasswd: "",
		dbHost:   "localhost",
		dbName:   "sunhpc",
		dbPath:   "/var/lib/sunhpc",
		dbSocket: "/var/lib/sunhpc/mysql/mysql.sock",
		verbose:  false,
	}
}

// ==================== Connection parameter setters/getters ====================

// SetDBPasswd sets the database password explicitly.
func (db *DB) SetDBPasswd(passwd string) {
	db.mu.Lock()
	defer db.mu.Unlock()
	db.dbPasswd = passwd
}

// GetDBPasswd returns the database password. If none has been set, it is
// looked up in the "password = ..." entry of the current database user's
// .sunhpc.my.cnf (/root/.sunhpc.my.cnf for root, /home/<user>/.sunhpc.my.cnf
// otherwise) and cached. If the file cannot be read, "" is returned and
// nothing is cached.
func (db *DB) GetDBPasswd() string {
	db.mu.RLock()
	passwd := db.dbPasswd
	db.mu.RUnlock()
	if passwd != "" {
		return passwd
	}

	// Resolve the user BEFORE taking the write lock: GetDBUsername locks
	// db.mu itself and sync.RWMutex is not reentrant, so calling it while
	// holding the lock would deadlock.
	username := db.GetDBUsername()

	db.mu.Lock()
	defer db.mu.Unlock()
	if db.dbPasswd != "" {
		// Filled in by another goroutine while we were unlocked.
		return db.dbPasswd
	}

	var filename string
	switch username {
	case "root":
		filename = "/root/.sunhpc.my.cnf"
	default:
		filename = fmt.Sprintf("/home/%s/.sunhpc.my.cnf", username)
	}

	data, err := ioutil.ReadFile(filename)
	if err != nil {
		return ""
	}

	for _, line := range strings.Split(string(data), "\n") {
		// Split on the first '=' only, so passwords containing '=' survive.
		parts := strings.SplitN(strings.TrimSpace(line), "=", 2)
		if len(parts) == 2 && strings.TrimSpace(parts[0]) == "password" {
			db.dbPasswd = strings.TrimSpace(parts[1])
			break
		}
	}
	return db.dbPasswd
}

// SetDBUsername sets the database user name explicitly.
func (db *DB) SetDBUsername(name string) {
	db.mu.Lock()
	defer db.mu.Unlock()
	db.dbUser = name
}

// GetDBUsername returns the database user name, defaulting to (and caching)
// the USER environment variable when none has been set.
func (db *DB) GetDBUsername() string {
	db.mu.RLock()
	user := db.dbUser
	db.mu.RUnlock()
	if user != "" {
		return user
	}

	db.mu.Lock()
	defer db.mu.Unlock()
	// Re-check: SetDBUsername may have run between RUnlock and Lock.
	if db.dbUser == "" {
		db.dbUser = os.Getenv("USER")
	}
	return db.dbUser
}

// SetDBHostname sets the database host name.
func (db *DB) SetDBHostname(host string) {
	db.mu.Lock()
	defer db.mu.Unlock()
	db.dbHost = host
}

// GetDBHostname returns the database host name.
func (db *DB) GetDBHostname() string {
	db.mu.RLock()
	defer db.mu.RUnlock()
	return db.dbHost
}

// SetDBName sets the database name (the SQLite file is <dbPath>/<name>.db).
func (db *DB) SetDBName(name string) {
	db.mu.Lock()
	defer db.mu.Unlock()
	db.dbName = name
}

// GetDBName returns the database name.
func (db *DB) GetDBName() string {
	db.mu.RLock()
	defer db.mu.RUnlock()
	return db.dbName
}

// SetDBPath sets the directory that holds the SQLite database file.
func (db *DB) SetDBPath(path string) {
	db.mu.Lock()
	defer db.mu.Unlock()
	db.dbPath = path
}

// GetDBPath returns the directory that holds the SQLite database file.
func (db *DB) GetDBPath() string {
	db.mu.RLock()
	defer db.mu.RUnlock()
	return db.dbPath
}

// SetVerbose enables or disables verbose SQL logging.
func (db *DB) SetVerbose(verbose bool) {
	db.mu.Lock()
	defer db.mu.Unlock()
	db.verbose = verbose
}

// SetForceInit makes the next Connect drop and recreate all tables.
func (db *DB) SetForceInit(force bool) {
	db.mu.Lock()
	defer db.mu.Unlock()
	db.forceInit = force
}

// ==================== Connection management ====================

// Connect opens the SQLite database <dbPath>/<dbName>.db (creating dbPath
// if needed) with foreign keys and WAL journaling enabled, checks out a
// dedicated connection, and initialises the schema. A non-empty SUNHPCDEBUG
// environment variable turns on verbose mode. On failure no newly opened
// pool or connection is left behind.
func (db *DB) Connect() error {
	log.Debug("连接数据库...")
	db.mu.Lock()
	defer db.mu.Unlock()

	log.Debug("检查 SUNHPCDEBUG 环境变量...")
	if os.Getenv("SUNHPCDEBUG") != "" {
		db.verbose = true
	}

	dbFullPath := filepath.Join(db.dbPath, db.dbName+".db")
	log.Debugf("数据库路径: %s", dbFullPath)

	log.Debug("确保数据库目录存在...")
	if err := os.MkdirAll(db.dbPath, 0755); err != nil {
		return fmt.Errorf("创建数据库目录失败: %w", err)
	}

	engine, err := sql.Open("sqlite3", dbFullPath+"?_foreign_keys=on&_journal_mode=WAL")
	log.Debugf("打开数据库连接...")
	if err != nil {
		return fmt.Errorf("打开数据库失败: %w", err)
	}
	engine.SetMaxOpenConns(10)
	engine.SetMaxIdleConns(5)
	engine.SetConnMaxLifetime(time.Hour)

	conn, err := engine.Conn(context.Background())
	log.Debugf("获取数据库连接...")
	if err != nil {
		engine.Close() // best effort; the checkout error is what matters
		return fmt.Errorf("获取连接失败: %w", err)
	}

	// A repeated Connect replaces the previous pool instead of leaking it.
	db.releaseConnLocked()
	if db.engine != nil {
		db.engine.Close()
	}
	db.engine = engine
	db.conn = conn

	if err := db.initSchema(); err != nil {
		db.releaseConnLocked()
		db.engine.Close()
		db.engine = nil
		return fmt.Errorf("初始化数据库表失败: %w", err)
	}

	if db.verbose {
		log.Infof("数据库连接成功: %s", dbFullPath)
	}
	return nil
}

// initSchema creates every table and index in one transaction, then inserts
// the default rows. It is skipped when the nodes table already exists,
// unless forceInit is set, in which case all tables are dropped first.
// The caller must hold db.mu and have set db.engine.
func (db *DB) initSchema() error {
	log.Debug("初始化数据库表结构...")

	// The nodes table serves as the marker for an initialised schema.
	var tableName string
	err := db.engine.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name='nodes'").Scan(&tableName)
	if err == nil && !db.forceInit {
		log.Debug("数据库表已存在,跳过初始化")
		return nil
	}

	if db.forceInit {
		log.Warn("强制重新初始化数据库表结构...")
	} else {
		log.Info("首次初始化数据库表结构...")
	}

	if db.forceInit {
		log.Info("删除现有表...")
		// Dependent tables are dropped before the tables they reference.
		dropSQLs := []string{
			`DROP TABLE IF EXISTS resolvechain;`,
			`DROP TABLE IF EXISTS hostselections;`,
			`DROP TABLE IF EXISTS attributes;`,
			`DROP TABLE IF EXISTS catindexes;`,
			`DROP TABLE IF EXISTS categories;`,
			`DROP TABLE IF EXISTS node_attrs;`,
			`DROP TABLE IF EXISTS aliases;`,
			`DROP TABLE IF EXISTS networks;`,
			`DROP TABLE IF EXISTS subnets;`,
			`DROP TABLE IF EXISTS software_installs;`,
			`DROP TABLE IF EXISTS memberships;`,
			`DROP TABLE IF EXISTS appliances;`,
			`DROP TABLE IF EXISTS nodes;`,
		}
		for _, stmt := range dropSQLs {
			if _, err := db.engine.Exec(stmt); err != nil {
				log.Warnf("删除表失败: %v", err)
			}
		}
		log.Info("现有表已删除")
	}

	tx, err := db.engine.Begin()
	if err != nil {
		return fmt.Errorf("开启事务失败: %w", err)
	}
	// Undo partial DDL on any early return; a no-op after a successful Commit.
	defer tx.Rollback()

	// Statements run one at a time. Referenced tables are created before
	// the tables whose foreign keys point at them.
	schemaSQLs := []string{
		`CREATE TABLE IF NOT EXISTS nodes (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			name TEXT NOT NULL UNIQUE,
			rack INTEGER DEFAULT 0,
			rank INTEGER DEFAULT 0,
			membership_id INTEGER,
			cpus INTEGER DEFAULT 0,
			memory INTEGER DEFAULT 0,
			disk INTEGER DEFAULT 0,
			os TEXT,
			kernel TEXT,
			last_state_change DATETIME DEFAULT CURRENT_TIMESTAMP,
			created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
			updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
		);`,
		`CREATE TABLE IF NOT EXISTS appliances (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			name TEXT NOT NULL UNIQUE,
			description TEXT,
			node_type TEXT DEFAULT 'compute'
		);`,
		`CREATE TABLE IF NOT EXISTS memberships (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			name TEXT NOT NULL UNIQUE,
			appliance_id INTEGER,
			FOREIGN KEY (appliance_id) REFERENCES appliances(id) ON DELETE SET NULL
		);`,
		`CREATE TABLE IF NOT EXISTS subnets (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			name TEXT UNIQUE,
			network TEXT,
			netmask TEXT,
			gateway TEXT,
			dns_zone TEXT,
			is_private INTEGER DEFAULT 1
		);`,
		`CREATE TABLE IF NOT EXISTS networks (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			node_id INTEGER NOT NULL,
			name TEXT,
			ip TEXT UNIQUE,
			mac TEXT UNIQUE,
			subnet_id INTEGER,
			interface TEXT DEFAULT 'eth0',
			FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE CASCADE,
			FOREIGN KEY (subnet_id) REFERENCES subnets(id) ON DELETE SET NULL
		);`,
		`CREATE TABLE IF NOT EXISTS aliases (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			node_id INTEGER NOT NULL,
			name TEXT NOT NULL,
			FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE CASCADE,
			UNIQUE(node_id, name)
		);`,
		`CREATE TABLE IF NOT EXISTS categories (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			name TEXT NOT NULL UNIQUE
		);`,
		`CREATE TABLE IF NOT EXISTS catindexes (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			name TEXT NOT NULL,
			category_id INTEGER NOT NULL,
			FOREIGN KEY (category_id) REFERENCES categories(id) ON DELETE CASCADE,
			UNIQUE(name, category_id)
		);`,
		`CREATE TABLE IF NOT EXISTS attributes (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			attr TEXT NOT NULL,
			value TEXT,
			category_id INTEGER NOT NULL,
			catindex_id INTEGER NOT NULL,
			FOREIGN KEY (category_id) REFERENCES categories(id) ON DELETE CASCADE,
			FOREIGN KEY (catindex_id) REFERENCES catindexes(id) ON DELETE CASCADE,
			UNIQUE(attr, category_id, catindex_id)
		);`,
		`CREATE TABLE IF NOT EXISTS node_attrs (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			node_id INTEGER NOT NULL,
			attr TEXT NOT NULL,
			value TEXT,
			FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE CASCADE,
			UNIQUE(node_id, attr)
		);`,
		`CREATE TABLE IF NOT EXISTS hostselections (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			host_id INTEGER NOT NULL,
			category_id INTEGER NOT NULL,
			selection TEXT NOT NULL,
			FOREIGN KEY (host_id) REFERENCES nodes(id) ON DELETE CASCADE,
			FOREIGN KEY (category_id) REFERENCES categories(id) ON DELETE CASCADE,
			UNIQUE(host_id, category_id, selection)
		);`,
		`CREATE TABLE IF NOT EXISTS resolvechain (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			category_id INTEGER NOT NULL,
			precedence INTEGER NOT NULL,
			FOREIGN KEY (category_id) REFERENCES categories(id) ON DELETE CASCADE,
			UNIQUE(category_id, precedence)
		);`,
		`CREATE TABLE IF NOT EXISTS software_installs (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			name TEXT NOT NULL,
			version TEXT,
			install_type TEXT,
			node_id INTEGER,
			status TEXT,
			installed_at DATETIME DEFAULT CURRENT_TIMESTAMP,
			installed_by TEXT,
			FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE SET NULL
		);`,

		// Indexes.
		`CREATE INDEX IF NOT EXISTS idx_nodes_name ON nodes(name);`,
		`CREATE INDEX IF NOT EXISTS idx_networks_ip ON networks(ip);`,
		`CREATE INDEX IF NOT EXISTS idx_networks_mac ON networks(mac);`,
		`CREATE INDEX IF NOT EXISTS idx_attributes_lookup ON attributes(attr, category_id, catindex_id);`,
		`CREATE INDEX IF NOT EXISTS idx_node_attrs_lookup ON node_attrs(node_id, attr);`,
		`CREATE INDEX IF NOT EXISTS idx_hostselections_host ON hostselections(host_id);`,
		`CREATE INDEX IF NOT EXISTS idx_resolvechain_precedence ON resolvechain(precedence);`,
	}

	for i, stmt := range schemaSQLs {
		if strings.TrimSpace(stmt) == "" {
			continue
		}
		log.Debugf("执行SQL[%d]: %s", i, strings.TrimSpace(strings.Split(stmt, "\n")[0]))
		if _, err := tx.Exec(stmt); err != nil {
			return fmt.Errorf("执行SQL[%d]失败: %w\nSQL: %s", i, err, stmt)
		}
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("提交事务失败: %w", err)
	}
	log.Info("数据库表结构创建成功")

	return db.insertDefaultData()
}

// insertDefaultData seeds the default categories, category indexes,
// resolve-chain precedences and appliance types. INSERT OR IGNORE makes it
// safe to run against tables that already hold these rows.
func (db *DB) insertDefaultData() error {
	log.Debug("插入默认数据...")

	// Default categories.
	categories := []string{"global", "host", "os", "appliance", "network"}
	for _, cat := range categories {
		_, err := db.engine.Exec(
			"INSERT OR IGNORE INTO categories (name) VALUES (?)",
			cat,
		)
		if err != nil {
			return err
		}
	}

	log.Debug("插入默认类别索引...")
	// Default category indexes, each attached to a category by name.
	catIndexes := []struct {
		catName string
		idxName string
	}{
		{"global", "global"},
		{"os", "linux"},
		{"network", "private"},
	}
	for _, ci := range catIndexes {
		_, err := db.engine.Exec(`
			INSERT OR IGNORE INTO catindexes (name, category_id)
			SELECT ?, id FROM categories WHERE name = ?
		`, ci.idxName, ci.catName)
		if err != nil {
			return err
		}
	}

	log.Debug("插入默认解析链优先级...")
	// Default resolve-chain precedence per category.
	precedence := []struct {
		catName string
		level   int
	}{
		{"global", 1},
		{"os", 2},
		{"appliance", 3},
		{"host", 4},
		{"network", 5},
	}
	for _, p := range precedence {
		_, err := db.engine.Exec(`
			INSERT OR IGNORE INTO resolvechain (category_id, precedence)
			SELECT id, ? FROM categories WHERE name = ?
		`, p.level, p.catName)
		if err != nil {
			return err
		}
	}

	log.Debug("插入默认设备类型...")
	// Default appliance (node) types.
	appliances := []struct {
		name string
		desc string
		typ  string
	}{
		{"frontend", "管理节点", "master"},
		{"compute", "计算节点", "compute"},
		{"login", "登录节点", "login"},
		{"storage", "存储节点", "storage"},
	}
	for _, a := range appliances {
		_, err := db.engine.Exec(
			"INSERT OR IGNORE INTO appliances (name, description, node_type) VALUES (?, ?, ?)",
			a.name, a.desc, a.typ,
		)
		if err != nil {
			return err
		}
	}

	log.Debug("插入默认数据完成...")
	return nil
}

// ==================== Core query methods ====================

// Execute runs query on the dedicated connection, mirroring Rocks'
// execute().
//
// A statement starting with SELECT (case-insensitive, ignoring leading
// whitespace) replaces the pending result set, which is then read with
// FetchOne/FetchAll, and Execute returns 0. Any other statement runs in
// autocommit mode and Execute returns the number of affected rows.
// If the statement fails, the connection is renewed and the statement is
// retried once; if renewal itself fails, the original error is returned.
func (db *DB) Execute(query string, args ...interface{}) (int64, error) {
	db.mu.RLock()
	conn := db.conn
	verbose := db.verbose
	db.mu.RUnlock()

	if conn == nil {
		return 0, fmt.Errorf("没有活动数据库连接")
	}
	if verbose {
		log.Debugf("执行SQL: %s %v", query, args)
	}

	ctx := context.Background()
	upperQuery := strings.ToUpper(strings.TrimSpace(query))

	if strings.HasPrefix(upperQuery, "SELECT") {
		rows, err := conn.QueryContext(ctx, query, args...)
		if err != nil {
			// Retry once on a fresh connection.
			if conn, err = db.renewForRetry(err); err == nil {
				rows, err = conn.QueryContext(ctx, query, args...)
			}
		}
		if err != nil {
			return 0, err
		}

		// Replace (and close) the previous result set.
		db.mu.Lock()
		if db.results != nil {
			db.results.Close()
		}
		db.results = rows
		db.mu.Unlock()
		return 0, nil
	}

	// INSERT/UPDATE/DELETE etc. run via Exec (autocommit).
	result, err := conn.ExecContext(ctx, query, args...)
	if err != nil {
		// Retry once on a fresh connection.
		if conn, err = db.renewForRetry(err); err == nil {
			result, err = conn.ExecContext(ctx, query, args...)
		}
	}
	if err != nil {
		return 0, err
	}

	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return 0, err
	}
	if verbose {
		log.Debugf("影响行数: %d", rowsAffected)
	}
	return rowsAffected, nil
}

// renewForRetry is called by Execute after a statement failed with cause.
// It renews the dedicated connection and returns the new one. If renewal
// fails (or the connection was closed meanwhile), it returns cause
// unchanged so the caller reports the statement's own error.
func (db *DB) renewForRetry(cause error) (*sql.Conn, error) {
	if err := db.RenewConnection(); err != nil {
		log.Warnf("重连数据库失败: %v", err)
		return nil, cause
	}
	db.mu.RLock()
	conn := db.conn
	db.mu.RUnlock()
	if conn == nil {
		return nil, cause
	}
	return conn, nil
}

// scanRow reads the current row of rows into a column-name -> value map,
// converting []byte values to string.
func scanRow(rows *sql.Rows, columns []string) (map[string]interface{}, error) {
	values := make([]interface{}, len(columns))
	scanArgs := make([]interface{}, len(columns))
	for i := range values {
		scanArgs[i] = &values[i]
	}
	if err := rows.Scan(scanArgs...); err != nil {
		return nil, err
	}

	row := make(map[string]interface{}, len(columns))
	for i, col := range columns {
		if b, ok := values[i].([]byte); ok {
			row[col] = string(b)
		} else {
			row[col] = values[i]
		}
	}
	return row, nil
}

// FetchOne returns the next row of the pending result set as a map keyed by
// column name, mirroring Rocks' fetchone(). It returns (nil, nil) when
// there is no result set or no more rows, and (nil, err) if iteration
// stopped because of an error.
func (db *DB) FetchOne() (map[string]interface{}, error) {
	db.mu.RLock()
	results := db.results
	db.mu.RUnlock()

	if results == nil {
		return nil, nil
	}
	if !results.Next() {
		// Distinguish normal exhaustion from an iteration error.
		return nil, results.Err()
	}

	columns, err := results.Columns()
	if err != nil {
		return nil, err
	}
	return scanRow(results, columns)
}

// FetchAll returns all remaining rows of the pending result set, each as a
// map keyed by column name, mirroring Rocks' fetchall(). It returns
// (nil, nil) when there is no result set, and an error if the rows are
// already closed or iteration fails part-way.
func (db *DB) FetchAll() ([]map[string]interface{}, error) {
	db.mu.RLock()
	results := db.results
	db.mu.RUnlock()

	if results == nil {
		return nil, nil
	}

	columns, err := results.Columns()
	if err != nil {
		return nil, err
	}

	var rows []map[string]interface{}
	for results.Next() {
		row, err := scanRow(results, columns)
		if err != nil {
			return nil, err
		}
		rows = append(rows, row)
	}
	// Without this check an I/O error mid-iteration would look like the end
	// of the data.
	if err := results.Err(); err != nil {
		return nil, err
	}
	return rows, nil
}

// ==================== Connection maintenance ====================

// releaseConnLocked closes the pending result set and then the dedicated
// connection, resetting both to nil. Rows must be closed first, because
// sql.Conn.Close waits until every Rows obtained from that Conn is closed.
// The caller must hold db.mu.
func (db *DB) releaseConnLocked() {
	if db.results != nil {
		db.results.Close()
		db.results = nil
	}
	if db.conn != nil {
		db.conn.Close()
		db.conn = nil
	}
}

// RenewConnection discards the pending result set and the dedicated
// connection, then checks a fresh connection out of the pool.
func (db *DB) RenewConnection() error {
	db.mu.Lock()
	defer db.mu.Unlock()

	db.releaseConnLocked()
	if db.engine == nil {
		return fmt.Errorf("没有活动数据库连接")
	}

	conn, err := db.engine.Conn(context.Background())
	if err != nil {
		return err
	}
	db.conn = conn
	return nil
}

// Close closes the pending result set, the dedicated connection and the
// connection pool. Calling it again is harmless.
func (db *DB) Close() error {
	db.mu.Lock()
	defer db.mu.Unlock()

	db.releaseConnLocked()
	if db.engine == nil {
		return nil
	}
	err := db.engine.Close()
	db.engine = nil
	return err
}

// CloseConnection closes only the pending result set and the dedicated
// connection; the connection pool stays open.
func (db *DB) CloseConnection() error {
	db.mu.Lock()
	defer db.mu.Unlock()

	db.releaseConnLocked()
	return nil
}

// ==================== Singleton ====================

// Configuration recorded by the first GetInstanceWithConfig call.
var (
	instanceConfigured bool
	instanceDBPath     string
	instanceDBName     string
)

// GetInstance returns the process-wide DB, creating and connecting it with
// the default settings on first use.
func GetInstance() (*DB, error) {
	return GetInstanceWithConfig("", "")
}

// GetInstanceWithConfig returns the process-wide DB. Only the first call
// does any work: it creates the DB, applies the non-empty dbPath/dbName
// overrides and connects. Later calls ignore their arguments and return the
// same DB together with the same initialisation error.
func GetInstanceWithConfig(dbPath, dbName string) (*DB, error) {
	once.Do(func() {
		globalDB = NewDB()
		log.Debug("创建数据库实例...")
		globalDB.SetDBUsername(globalDB.GetDBUsername())

		if dbPath != "" {
			globalDB.SetDBPath(dbPath)
			log.Debugf("设置数据库路径: %s", dbPath)
		}
		if dbName != "" {
			globalDB.SetDBName(dbName)
			log.Debugf("设置数据库名称: %s", dbName)
		}

		instanceConfigured = (dbPath != "" || dbName != "")
		if dbPath != "" {
			instanceDBPath = dbPath
		}
		if dbName != "" {
			instanceDBName = dbName
		}

		// Stored at package level so that callers after the first one also
		// learn that initialisation failed.
		globalErr = globalDB.Connect()
	})
	return globalDB, globalErr
}

// IsInstanceConfigured reports whether the singleton was created with a
// non-empty dbPath or dbName.
func IsInstanceConfigured() bool {
	return instanceConfigured
}

// GetInstanceConfig returns the dbPath and dbName overrides the singleton
// was created with ("" where the default was used).
func GetInstanceConfig() (dbPath, dbName string) {
	return instanceDBPath, instanceDBName
}