129 lines
2.8 KiB
Go
129 lines
2.8 KiB
Go
package db
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"sync"
|
|
|
|
_ "github.com/go-sql-driver/mysql"
|
|
_ "github.com/mattn/go-sqlite3"
|
|
|
|
"sunhpc/internal/config"
|
|
"sunhpc/internal/log"
|
|
)
|
|
|
|
// DB is a thin wrapper that owns a *sql.DB connection pool.
type DB struct {
	engine *sql.DB // pool set by GetDB; nil in a zero-value DB
}

// Engine exposes the wrapped *sql.DB so callers can run queries directly.
func (d *DB) Engine() *sql.DB { return d.engine }
|
|
|
|
// InitSchema initializes the database schema.
|
|
// If force is true, drops existing tables before recreating them.
|
|
func (d *DB) InitSchema(force bool) error {
|
|
db := d.engine
|
|
|
|
if force {
|
|
if err := dropTables(db); err != nil {
|
|
return fmt.Errorf("failed to drop tables: %w", err)
|
|
}
|
|
}
|
|
|
|
// ✅ 调用 schema.go 中的函数
|
|
for _, ddl := range CreateTableStatements() {
|
|
if _, err := db.Exec(ddl); err != nil {
|
|
return fmt.Errorf("failed to create table: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func dropTables(db *sql.DB) error {
|
|
// ✅ 调用 schema.go 中的函数
|
|
for _, table := range DropTableOrder() {
|
|
if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS `%s`", table)); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// --- Singleton DB Instance ---
|
|
var (
|
|
globalDB *DB
|
|
initOnce sync.Once
|
|
initErr error
|
|
)
|
|
|
|
func GetDB() (*DB, error) {
|
|
initOnce.Do(func() {
|
|
cfg, err := config.LoadConfig()
|
|
if err != nil {
|
|
initErr = fmt.Errorf("数据库配置文件加载失败: %w", err)
|
|
return
|
|
}
|
|
|
|
if _, err := os.Stat(cfg.DB.Path); err != nil {
|
|
// 创建数据库目录
|
|
if err := os.MkdirAll(cfg.DB.Path, 0755); err != nil {
|
|
log.Fatalf("创建数据库目录失败: %v", err)
|
|
}
|
|
log.Infof("数据库目录创建成功: %s", cfg.DB.Path)
|
|
}
|
|
|
|
var dsn string
|
|
var driver string
|
|
|
|
switch cfg.DB.Type {
|
|
case "sqlite":
|
|
driver = "sqlite3"
|
|
fullPath := filepath.Join(cfg.DB.Path, cfg.DB.Name)
|
|
dsn = fmt.Sprintf("%s?_foreign_keys=on&_journal_mode=WAL&_timeout=5000", fullPath)
|
|
case "mysql":
|
|
driver = "mysql"
|
|
if cfg.DB.Socket != "" {
|
|
dsn = fmt.Sprintf("%s:%s@unix(%s)/%s?parseTime=true&loc=Local",
|
|
cfg.DB.User, cfg.DB.Password, cfg.DB.Socket, cfg.DB.Name)
|
|
} else {
|
|
dsn = fmt.Sprintf("%s:%s@tcp(%s)/%s?parseTime=true&loc=Local",
|
|
cfg.DB.User, cfg.DB.Password, cfg.DB.Host, cfg.DB.Name)
|
|
}
|
|
default:
|
|
initErr = fmt.Errorf("unsupported database type: %s", cfg.DB.Type)
|
|
return
|
|
}
|
|
|
|
engine, err := sql.Open(driver, dsn)
|
|
if err != nil {
|
|
initErr = fmt.Errorf("failed to open database: %w", err)
|
|
return
|
|
}
|
|
|
|
if err := engine.Ping(); err != nil {
|
|
engine.Close()
|
|
initErr = fmt.Errorf("failed to ping database: %w", err)
|
|
return
|
|
}
|
|
|
|
globalDB = &DB{engine: engine}
|
|
})
|
|
|
|
return globalDB, initErr
|
|
}
|
|
|
|
// MustGetDB is a helper that panics on error (use in main/init only).
|
|
func MustGetDB() *DB {
|
|
db, err := GetDB()
|
|
if err != nil {
|
|
log.Fatalf("数据库初始化失败: %v", err)
|
|
}
|
|
return db
|
|
}
|