Files
sunhpc-go/internal/db/db.go
2026-02-15 07:18:14 +08:00

129 lines
2.8 KiB
Go

package db
import (
"database/sql"
"fmt"
"os"
"path/filepath"
"sync"
_ "github.com/go-sql-driver/mysql"
_ "github.com/mattn/go-sqlite3"
"sunhpc/internal/config"
"sunhpc/internal/log"
)
// DB is a thin handle around a database/sql connection pool. The pool is
// kept unexported; callers reach it through Engine. In this file it is
// populated only by GetDB.
type DB struct {
	engine *sql.DB // wrapped connection pool
}

// Engine exposes the wrapped *sql.DB so callers can run queries on it
// directly. It returns nil for a zero-value DB.
func (d *DB) Engine() *sql.DB { return d.engine }
// InitSchema creates the database schema by executing each statement
// returned by CreateTableStatements (defined in schema.go), in the order
// returned.
//
// When force is true, existing tables are dropped first via dropTables,
// which discards their data. Statements are executed one by one, not in a
// transaction. The method stops at the first failing statement and
// returns its error wrapped.
func (d *DB) InitSchema(force bool) error {
	conn := d.engine

	if force {
		if dropErr := dropTables(conn); dropErr != nil {
			return fmt.Errorf("failed to drop tables: %w", dropErr)
		}
	}

	statements := CreateTableStatements()
	for _, stmt := range statements {
		if _, execErr := conn.Exec(stmt); execErr != nil {
			return fmt.Errorf("failed to create table: %w", execErr)
		}
	}
	return nil
}
// dropTables drops every table named by DropTableOrder (defined in
// schema.go), in the order it returns them. It uses DROP TABLE IF EXISTS,
// so a table that does not exist is not an error.
//
// It stops at the first failure and returns an error that names the table
// and wraps the driver error.
// NOTE(review): order presumably matters for foreign-key dependencies
// (dependents first) — confirm against DropTableOrder.
func dropTables(db *sql.DB) error {
	for _, table := range DropTableOrder() {
		// Identifiers cannot be bound as query parameters, so the name is
		// interpolated (backtick-quoted). The names are expected to come from
		// the package's own schema list, not from user input — verify if that
		// ever changes.
		if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS `%s`", table)); err != nil {
			return fmt.Errorf("drop table %q: %w", table, err)
		}
	}
	return nil
}
// --- Singleton DB Instance ---

// Package-level state for the lazily initialized shared connection pool.
// These are written only inside initOnce.Do in GetDB.
var (
	globalDB *DB       // set only when initialization fully succeeded
	initOnce sync.Once // guards the one-time initialization in GetDB
	initErr  error     // initialization error, if any; never cleared
)

// GetDB returns the process-wide *DB, creating it on first use.
//
// The first call does the following:
//   - loads the configuration via config.LoadConfig;
//   - builds a driver-specific DSN from cfg.DB;
//   - opens the pool and verifies it with Ping.
//
// Supported cfg.DB.Type values:
//   - "sqlite": file cfg.DB.Name inside directory cfg.DB.Path. The
//     directory is created if it cannot be stat'ed.
//   - "mysql": a unix socket when cfg.DB.Socket is set, otherwise TCP to
//     cfg.DB.Host.
//
// Initialization runs exactly once (sync.Once). If it fails, this and every
// later call return a nil *DB and the same error; there is no retry.
func GetDB() (*DB, error) {
	initOnce.Do(func() {
		cfg, err := config.LoadConfig()
		if err != nil {
			initErr = fmt.Errorf("数据库配置文件加载失败: %w", err)
			return
		}

		var driver, dsn string
		switch cfg.DB.Type {
		case "sqlite":
			// Only SQLite keeps its data in a local directory, so the directory
			// is prepared here instead of for every backend. A MySQL config may
			// leave Path empty, and MkdirAll("") would fail.
			if _, err := os.Stat(cfg.DB.Path); err != nil {
				if err := os.MkdirAll(cfg.DB.Path, 0755); err != nil {
					// Report through initErr rather than terminating the
					// process: callers of GetDB expect an error return.
					initErr = fmt.Errorf("创建数据库目录失败: %w", err)
					return
				}
				log.Infof("数据库目录创建成功: %s", cfg.DB.Path)
			}
			driver = "sqlite3"
			fullPath := filepath.Join(cfg.DB.Path, cfg.DB.Name)
			dsn = fmt.Sprintf("%s?_foreign_keys=on&_journal_mode=WAL&_timeout=5000", fullPath)
		case "mysql":
			driver = "mysql"
			// A configured socket takes precedence over TCP.
			if cfg.DB.Socket != "" {
				dsn = fmt.Sprintf("%s:%s@unix(%s)/%s?parseTime=true&loc=Local",
					cfg.DB.User, cfg.DB.Password, cfg.DB.Socket, cfg.DB.Name)
			} else {
				dsn = fmt.Sprintf("%s:%s@tcp(%s)/%s?parseTime=true&loc=Local",
					cfg.DB.User, cfg.DB.Password, cfg.DB.Host, cfg.DB.Name)
			}
		default:
			initErr = fmt.Errorf("unsupported database type: %s", cfg.DB.Type)
			return
		}

		engine, err := sql.Open(driver, dsn)
		if err != nil {
			initErr = fmt.Errorf("failed to open database: %w", err)
			return
		}
		// sql.Open does not necessarily connect; Ping verifies the settings.
		if err := engine.Ping(); err != nil {
			// Best-effort cleanup; the ping error is the one worth reporting.
			_ = engine.Close()
			initErr = fmt.Errorf("failed to ping database: %w", err)
			return
		}
		globalDB = &DB{engine: engine}
	})
	return globalDB, initErr
}
// MustGetDB is like GetDB but treats an initialization failure as fatal:
// the error is passed to log.Fatalf instead of being returned. Use it only
// from main or init code, where no caller is left to handle the error.
// NOTE(review): whether log.Fatalf exits or panics depends on the project's
// internal/log package — confirm.
func MustGetDB() *DB {
	inst, err := GetDB()
	if err == nil {
		return inst
	}
	log.Fatalf("数据库初始化失败: %v", err)
	return inst
}