package db import ( "bufio" "database/sql" "fmt" "os" "path/filepath" "strings" "sync" _ "github.com/go-sql-driver/mysql" _ "github.com/mattn/go-sqlite3" "sunhpc/internal/config" "sunhpc/internal/log" ) // DB wraps the sql.DB connection pool. type DB struct { engine *sql.DB config *config.DBConfig // 保存配置 } /* // Engine returns the underlying *sql.DB. func (d *DB) Engine() *sql.DB { return d.engine } */ func ConfirmWithRetry(prompt string, maxAttempts int) bool { reader := bufio.NewReader(os.Stdin) for attempt := 1; attempt <= maxAttempts; attempt++ { log.Infof("%s [y/n]", prompt) response, err := reader.ReadString('\n') if err != nil { continue } response = strings.ToLower(strings.TrimSpace(response)) switch response { case "y", "yes": return true case "n", "no", "": return false default: if attempt < maxAttempts { log.Warnf( "⚠️ 无效输入、请输入 'y' 或 'n'(剩余尝试次数: %d)", maxAttempts-attempt) } } } log.Warn("⚠️ 警告:尝试次数过多、操作已取消") return false } // InitSchema initializes the database schema. // If force is true, drops existing tables before recreating them. func (d *DB) InitSchema(force bool) error { fullPath := filepath.Join(d.config.Path, d.config.Name) // 检查文件是否存在 _, err := os.Stat(fullPath) fileExists := err == nil // 处理不同的场景 switch { case !fileExists: // 场景1:文件不存在,连接并创建(allowCreate = true). log.Infof("数据库文件不存在,将创建: %s", fullPath) if err := d.Connect(true); err != nil { return err } return createTables(d.engine) case fileExists && !force: // 场景2:文件存在、无 force 参数、提示友好退出. log.Warnf("数据库文件已存在: %s", fullPath) log.Warn("如果需要强制重新初始化,请添加 --force 参数") log.Warn("数据库已存在、退出初始化操作.") os.Exit(1) case fileExists && force: // 场景3:文件存在、force 参数 -> 需要用户确认并重建. log.Warn("警告:强制重新初始化将清空数据库中的所有数据!") if !ConfirmWithRetry("是否继续?", 3) { return fmt.Errorf("用户取消操作") } // 连接现有数据库(allowCreate = true, 因为文件已经存在) if err := d.Connect(true); err != nil { return err } // 清空现有数据. if err := dropTables(d.engine); err != nil { return err } // 清空表 if err := dropTriggers(d.engine); err != nil { return err } log.Info("已清空现有--数据库触发器") return createTables(d.engine) } return nil } // 辅助函数:检查文件是否存在 func fileExists(path string) bool { _, err := os.Stat(path) if err == nil { return true } if os.IsNotExist(err) { return false } // 其他问题(如权限问题)也视为文件不存在,但应该记录日志 log.Debugf("检查数据库文件状态失败: %v", err) return false } // 辅助函数: 创建数据库表 func createTables(db *sql.DB) error { // ✅ 调用 schema.go 中的函数 for _, ddl := range CreateTableStatements() { log.Debugf("执行: %s", ddl) if _, err := db.Exec(ddl); err != nil { return fmt.Errorf("数据表创建失败: %w", err) } } log.Info("数据库表创建成功") /* 使用sqlite3命令 测试数据库是否存在表 ✅ 查询所有表 sqlite3 /var/lib/sunhpc/sunhpc.db .tables # 查看所有表 select * from sqlite_master where type='table'; # 查看表定义 PRAGMA integrity_check; # 检查数据库完整性 */ return nil } func dropTables(db *sql.DB) error { // ✅ 调用 schema.go 中的函数 for _, table := range DropTableOrder() { if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS `%s`", table)); err != nil { return err } } return nil } func dropTriggers(db *sql.DB) error { // ✅ 调用 schema.go 中的函数 for _, trigger := range DropTriggerStatements() { if _, err := db.Exec(fmt.Sprintf("DROP TRIGGER IF EXISTS `%s`", trigger)); err != nil { return err } } return nil } // --- Singleton DB Instance --- var ( globalDB *DB initOnce sync.Once initErr error ) func GetDB() (*DB, error) { initOnce.Do(func() { cfg, err := config.LoadConfig() if err != nil { initErr = fmt.Errorf("数据库配置文件加载失败: %w", err) return } globalDB = &DB{ config: &cfg.DB, } }) return globalDB, initErr } func (d *DB) Connect(allowCreate bool) error { // 如果已经连接,直接返回 if d.engine != nil { return nil } switch d.config.Type { case "sqlite": fullPath := filepath.Join(d.config.Path, d.config.Name) // 检查文件是否存在 _, err := os.Stat(fullPath) fileExists := err == nil // 如果文件不存在且不允许创建,返回错误 if !fileExists && !allowCreate { return fmt.Errorf("数据库文件不存在: %s, 请先初始化.", fullPath) } // 确保目录存在 if err := os.MkdirAll(d.config.Path, 0755); err != nil { return fmt.Errorf("创建数据库目录失败: %w", err) } // 连接参数, 开启外键约束(PRAGMA foreign_keys = ON)、WAL 模式、5秒超时 dsn := fmt.Sprintf("%s?_foreign_keys=on&_journal_mode=WAL&_timeout=5000", fullPath) engine, err := sql.Open("sqlite3", dsn) if err != nil { return fmt.Errorf("数据库打开失败: %w", err) } if err := engine.Ping(); err != nil { engine.Close() return fmt.Errorf("数据库连接失败: %w", err) } d.engine = engine case "mysql": // TODO: 实现 MySQL 连接逻辑 return fmt.Errorf("mysql 数据库连接未实现") } return nil } // Close 关闭数据库连接 func (d *DB) Close() error { if d.engine != nil { return d.engine.Close() } return nil } // GetEngine 获取数据库引擎(自动连接) func (d *DB) GetEngine() (*sql.DB, error) { // 如果还没有连接,自动连接(但不创建新文件) if d.engine == nil { if err := d.Connect(false); err != nil { return nil, err } } return d.engine, nil } // MustGetDB is a helper that panics on error (use in main/init only). func MustGetDB() *DB { db, err := GetDB() if err != nil { log.Fatalf("数据库初始化失败: %v", err) } return db } func GetDBConfig() (*config.DBConfig, error) { cfg, err := config.LoadConfig() if err != nil { return nil, fmt.Errorf("数据库配置文件加载失败: %w", err) } return &cfg.DB, nil } func CheckDB() (*config.Config, error) { cfg, err := config.LoadConfig() if err != nil { log.Warnf("加载配置失败: %v", err) } // 统一转为小写,避免用户输入错误 dbType := strings.ToLower(cfg.DB.Type) // 打印配置(调试用) log.Debugf("数据库类型: %s", dbType) log.Debugf("数据库名称: %s", cfg.DB.Name) log.Debugf("数据库路径: %s", cfg.DB.Path) log.Debugf("数据库用户: %s", cfg.DB.User) log.Debugf("数据库主机: %s", cfg.DB.Host) log.Debugf("数据库套接字: %s", cfg.DB.Socket) log.Debugf("数据库详细日志: %v", cfg.DB.Verbose) // 支持 sqlite,mysql的常见别名 isSQLite := dbType == "sqlite" || dbType == "sqlite3" isMySQL := dbType == "mysql" // 检查数据库类型,只允许 sqlite 和 mysql if !isSQLite && !isMySQL { log.Fatalf("不支持的数据库类型: %s(仅支持 sqlite、sqlite3、mysql)", dbType) } // 检查数据库路径是否存在 if isSQLite { if _, err := os.Stat(cfg.DB.Path); os.IsNotExist(err) { log.Warnf("SQLite 数据库路径 %s 不存在", cfg.DB.Path) log.Warn("必须先执行 'sunhpc init database' 初始化数据库") } } return cfg, nil }