// Package db owns the process-wide database handle and the helpers that
// create or reset the schema on it.
package db

import (
	"database/sql"
	"fmt"
	"os"
	"path/filepath"
	"sync"

	_ "github.com/go-sql-driver/mysql"
	_ "github.com/mattn/go-sqlite3"

	"sunhpc/internal/config"
	"sunhpc/internal/log"
)

// DB wraps the sql.DB connection pool.
type DB struct {
	engine *sql.DB
}

// Engine returns the underlying *sql.DB.
func (d *DB) Engine() *sql.DB {
	return d.engine
}

// InitSchema initializes the database schema by executing every statement
// returned by CreateTableStatements, in order.
//
// If force is true, the tables listed by DropTableOrder are dropped first.
// The drop and create steps are issued as individual Exec calls, not inside
// a transaction, so a failure part-way through leaves the statements that
// already ran applied.
func (d *DB) InitSchema(force bool) error {
	db := d.engine

	if force {
		if err := dropTables(db); err != nil {
			return fmt.Errorf("failed to drop tables: %w", err)
		}
	}

	// CreateTableStatements is defined elsewhere in this package
	// (schema.go, per the original note).
	for _, ddl := range CreateTableStatements() {
		if _, err := db.Exec(ddl); err != nil {
			return fmt.Errorf("failed to create table: %w", err)
		}
	}
	return nil
}

// dropTables issues DROP TABLE IF EXISTS for each name returned by
// DropTableOrder, stopping at the first error.
func dropTables(db *sql.DB) error {
	// DropTableOrder is defined elsewhere in this package
	// (schema.go, per the original note).
	for _, table := range DropTableOrder() {
		if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS `%s`", table)); err != nil {
			return err
		}
	}
	return nil
}

// --- Singleton DB Instance ---

var (
	globalDB *DB       // shared handle; stays nil when initialization failed
	initOnce sync.Once // guards the one-time initialization in GetDB
	initErr  error     // outcome of that initialization, returned on every call
)

// GetDB returns the shared DB handle, initializing it on the first call.
//
// Initialization loads the configuration, builds a driver-specific DSN,
// opens the connection pool and pings it. It runs at most once: both the
// handle and any error are cached, so a failed initialization is reported
// again on every later call and is never retried.
//
// Supported cfg.DB.Type values are "sqlite" and "mysql". For "sqlite" the
// directory cfg.DB.Path is created if it cannot be stat'ed; for "mysql" the
// connection goes through cfg.DB.Socket when set, otherwise cfg.DB.Host.
//
// On error the returned *DB is nil.
func GetDB() (*DB, error) {
	initOnce.Do(func() {
		cfg, err := config.LoadConfig()
		if err != nil {
			initErr = fmt.Errorf("数据库配置文件加载失败: %w", err)
			return
		}

		var driver, dsn string
		switch cfg.DB.Type {
		case "sqlite":
			// Only SQLite stores its data under cfg.DB.Path, so the directory
			// is prepared here rather than for every driver. A failure is
			// reported through initErr like every other initialization error
			// instead of terminating the process from inside this function.
			if _, statErr := os.Stat(cfg.DB.Path); statErr != nil {
				// Create the database directory.
				if mkErr := os.MkdirAll(cfg.DB.Path, 0755); mkErr != nil {
					initErr = fmt.Errorf("创建数据库目录失败: %w", mkErr)
					return
				}
				log.Infof("数据库目录创建成功: %s", cfg.DB.Path)
			}

			driver = "sqlite3"
			fullPath := filepath.Join(cfg.DB.Path, cfg.DB.Name)
			dsn = fmt.Sprintf("%s?_foreign_keys=on&_journal_mode=WAL&_timeout=5000", fullPath)

		case "mysql":
			driver = "mysql"
			if cfg.DB.Socket != "" {
				dsn = fmt.Sprintf("%s:%s@unix(%s)/%s?parseTime=true&loc=Local",
					cfg.DB.User, cfg.DB.Password, cfg.DB.Socket, cfg.DB.Name)
			} else {
				dsn = fmt.Sprintf("%s:%s@tcp(%s)/%s?parseTime=true&loc=Local",
					cfg.DB.User, cfg.DB.Password, cfg.DB.Host, cfg.DB.Name)
			}

		default:
			initErr = fmt.Errorf("unsupported database type: %s", cfg.DB.Type)
			return
		}

		engine, err := sql.Open(driver, dsn)
		if err != nil {
			initErr = fmt.Errorf("failed to open database: %w", err)
			return
		}

		// sql.Open does not necessarily connect; Ping verifies the DSN works.
		if err := engine.Ping(); err != nil {
			// Best-effort cleanup: the ping failure is the error worth reporting.
			_ = engine.Close()
			initErr = fmt.Errorf("failed to ping database: %w", err)
			return
		}

		globalDB = &DB{engine: engine}
	})
	return globalDB, initErr
}

// MustGetDB returns the shared DB handle and reports any initialization
// error through log.Fatalf instead of returning it (use in main/init only).
//
// NOTE(review): log is the project's internal logger; this presumably
// terminates the process rather than panicking — confirm against
// sunhpc/internal/log.
func MustGetDB() *DB {
	db, err := GetDB()
	if err != nil {
		log.Fatalf("数据库初始化失败: %v", err)
	}
	return db
}