Files
sunhpc-go/internal/db/db.go
2026-02-18 17:09:52 +08:00

307 lines
7.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package db
import (
"bufio"
"database/sql"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
_ "github.com/go-sql-driver/mysql"
_ "github.com/mattn/go-sqlite3"
"sunhpc/internal/config"
"sunhpc/internal/log"
)
// DB wraps the sql.DB connection pool.
type DB struct {
engine *sql.DB
config *config.DBConfig // 保存配置
}
/*
// Engine returns the underlying *sql.DB.
func (d *DB) Engine() *sql.DB {
return d.engine
}
*/
func ConfirmWithRetry(prompt string, maxAttempts int) bool {
reader := bufio.NewReader(os.Stdin)
for attempt := 1; attempt <= maxAttempts; attempt++ {
log.Infof("%s [y/n]", prompt)
response, err := reader.ReadString('\n')
if err != nil {
continue
}
response = strings.ToLower(strings.TrimSpace(response))
switch response {
case "y", "yes":
return true
case "n", "no", "":
return false
default:
if attempt < maxAttempts {
log.Warnf(
"⚠️ 无效输入、请输入 'y' 或 'n'(剩余尝试次数: %d)",
maxAttempts-attempt)
}
}
}
log.Warn("⚠️ 警告:尝试次数过多、操作已取消")
return false
}
// InitSchema initializes the database schema.
// If force is true, drops existing tables before recreating them.
func (d *DB) InitSchema(force bool) error {
fullPath := filepath.Join(d.config.Path, d.config.Name)
// 检查文件是否存在
_, err := os.Stat(fullPath)
fileExists := err == nil
// 处理不同的场景
switch {
case !fileExists:
// 场景1文件不存在连接并创建(allowCreate = true).
log.Infof("数据库文件不存在,将创建: %s", fullPath)
if err := d.Connect(true); err != nil {
return err
}
return createTables(d.engine)
case fileExists && !force:
// 场景2文件存在、无 force 参数、提示友好退出.
log.Warnf("数据库文件已存在: %s", fullPath)
log.Warn("如果需要强制重新初始化,请添加 --force 参数")
log.Warn("数据库已存在、退出初始化操作.")
os.Exit(1)
case fileExists && force:
// 场景3文件存在、force 参数 -> 需要用户确认并重建.
log.Warn("警告:强制重新初始化将清空数据库中的所有数据!")
if !ConfirmWithRetry("是否继续?", 3) {
return fmt.Errorf("用户取消操作")
}
// 连接现有数据库(allowCreate = true, 因为文件已经存在)
if err := d.Connect(true); err != nil {
return err
}
// 清空现有数据.
if err := dropTables(d.engine); err != nil {
return err
}
// 清空表
if err := dropTriggers(d.engine); err != nil {
return err
}
log.Info("已清空现有数据库触发器")
return createTables(d.engine)
}
log.Info("数据库创建成功")
return nil
}
// 辅助函数:检查文件是否存在
func fileExists(path string) bool {
_, err := os.Stat(path)
if err == nil {
return true
}
if os.IsNotExist(err) {
return false
}
// 其他问题(如权限问题)也视为文件不存在,但应该记录日志
log.Debugf("检查数据库文件状态失败: %v", err)
return false
}
// 辅助函数: 创建数据库表
func createTables(db *sql.DB) error {
// ✅ 调用 schema.go 中的函数
for _, ddl := range CreateTableStatements() {
log.Debugf("执行: %s", ddl)
if _, err := db.Exec(ddl); err != nil {
return fmt.Errorf("数据表创建失败: %w", err)
}
}
log.Info("数据库表创建成功")
return nil
}
func dropTables(db *sql.DB) error {
// ✅ 调用 schema.go 中的函数
for _, table := range DropTableOrder() {
if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS `%s`", table)); err != nil {
return err
}
}
return nil
}
func dropTriggers(db *sql.DB) error {
// ✅ 调用 schema.go 中的函数
for _, trigger := range DropTriggerStatements() {
if _, err := db.Exec(fmt.Sprintf("DROP TRIGGER IF EXISTS `%s`", trigger)); err != nil {
return err
}
}
return nil
}
// --- Singleton DB Instance ---
var (
globalDB *DB
initOnce sync.Once
initErr error
)
func GetDB() (*DB, error) {
initOnce.Do(func() {
cfg, err := config.LoadConfig()
if err != nil {
initErr = fmt.Errorf("数据库配置文件加载失败: %w", err)
return
}
globalDB = &DB{
config: &cfg.DB,
}
})
return globalDB, initErr
}
func (d *DB) Connect(allowCreate bool) error {
// 如果已经连接,直接返回
if d.engine != nil {
return nil
}
switch d.config.Type {
case "sqlite":
fullPath := filepath.Join(d.config.Path, d.config.Name)
// 检查文件是否存在
_, err := os.Stat(fullPath)
fileExists := err == nil
// 如果文件不存在且不允许创建,返回错误
if !fileExists && !allowCreate {
return fmt.Errorf("数据库文件不存在: %s, 请先初始化.", fullPath)
}
// 确保目录存在
if err := os.MkdirAll(d.config.Path, 0755); err != nil {
return fmt.Errorf("创建数据库目录失败: %w", err)
}
// 连接参数
dsn := fmt.Sprintf("%s?_foreign_keys=on&_journal_mode=WAL&_timeout=5000",
fullPath)
engine, err := sql.Open("sqlite3", dsn)
if err != nil {
return fmt.Errorf("数据库打开失败: %w", err)
}
if err := engine.Ping(); err != nil {
engine.Close()
return fmt.Errorf("数据库连接失败: %w", err)
}
d.engine = engine
case "mysql":
// TODO: 实现 MySQL 连接逻辑
return fmt.Errorf("mysql 数据库连接未实现")
}
return nil
}
// Close 关闭数据库连接
func (d *DB) Close() error {
if d.engine != nil {
return d.engine.Close()
}
return nil
}
// GetEngine 获取数据库引擎(自动连接)
func (d *DB) GetEngine() (*sql.DB, error) {
// 如果还没有连接,自动连接(但不创建新文件)
if d.engine == nil {
if err := d.Connect(false); err != nil {
return nil, err
}
}
return d.engine, nil
}
// MustGetDB is a helper that panics on error (use in main/init only).
func MustGetDB() *DB {
db, err := GetDB()
if err != nil {
log.Fatalf("数据库初始化失败: %v", err)
}
return db
}
func GetDBConfig() (*config.DBConfig, error) {
cfg, err := config.LoadConfig()
if err != nil {
return nil, fmt.Errorf("数据库配置文件加载失败: %w", err)
}
return &cfg.DB, nil
}
func CheckDB() (*config.Config, error) {
cfg, err := config.LoadConfig()
if err != nil {
log.Warnf("加载配置失败: %v", err)
}
// 统一转为小写,避免用户输入错误
dbType := strings.ToLower(cfg.DB.Type)
// 打印配置(调试用)
log.Debugf("数据库类型: %s", dbType)
log.Debugf("数据库名称: %s", cfg.DB.Name)
log.Debugf("数据库路径: %s", cfg.DB.Path)
log.Debugf("数据库用户: %s", cfg.DB.User)
log.Debugf("数据库主机: %s", cfg.DB.Host)
log.Debugf("数据库套接字: %s", cfg.DB.Socket)
log.Debugf("数据库详细日志: %v", cfg.DB.Verbose)
// 支持 sqlitemysql的常见别名
isSQLite := dbType == "sqlite" || dbType == "sqlite3"
isMySQL := dbType == "mysql"
// 检查数据库类型,只允许 sqlite 和 mysql
if !isSQLite && !isMySQL {
log.Fatalf("不支持的数据库类型: %s(仅支持 sqlite、sqlite3、mysql)", dbType)
}
// 检查数据库路径是否存在
if isSQLite {
if _, err := os.Stat(cfg.DB.Path); os.IsNotExist(err) {
log.Warnf("SQLite 数据库路径 %s 不存在", cfg.DB.Path)
log.Warn("必须先执行 'sunhpc init database' 初始化数据库")
}
}
return cfg, nil
}