Files
sunhpc-go/pkg/wizard/config.go
2026-02-27 22:52:15 +08:00

186 lines
6.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package wizard
import (
"bufio"
"database/sql"
"fmt"
"net"
"os"
"strings"
"sunhpc/pkg/database"
"sunhpc/pkg/utils"
)
// 配置项映射:定义每个配置项对应的表名、键名
var configMappings = []struct {
table string
key string
getVal func(m *model) interface{} // 动态获取配置值的函数
}{
// attributes 表
{"attributes", "license", func(m *model) any { return m.config.License }},
{"attributes", "accepted", func(m *model) any { return m.config.AgreementAccepted }},
{"attributes", "country", func(m *model) any { return m.config.Country }},
{"attributes", "region", func(m *model) any { return m.config.Region }},
{"attributes", "timezone", func(m *model) any { return m.config.Timezone }},
{"attributes", "homepage", func(m *model) any { return m.config.HomePage }},
{"attributes", "dbaddress", func(m *model) any { return m.config.DBAddress }},
{"attributes", "software", func(m *model) any { return m.config.Software }},
// nodes 表
{"nodes", "name", func(m *model) any { return m.config.Hostname }},
// 公网设置表
{"public_network", "public_interface", func(m *model) any { return m.config.PublicInterface }},
{"public_network", "ip_address", func(m *model) any { return m.config.PublicIPAddress }},
{"public_network", "netmask", func(m *model) any { return m.config.PublicNetmask }},
{"public_network", "gateway", func(m *model) any { return m.config.PublicGateway }},
// 内网配置表
{"internal_network", "internal_interface", func(m *model) any { return m.config.InternalInterface }},
{"internal_network", "internal_ip", func(m *model) any { return m.config.InternalIPAddress }},
{"internal_network", "internal_mask", func(m *model) any { return m.config.InternalNetmask }},
// DNS配置表
{"dns_config", "dns_primary", func(m *model) any { return m.config.DNSPrimary }},
{"dns_config", "dns_secondary", func(m *model) any { return m.config.DNSSecondary }},
}
// saveConfig 入口函数:保存所有配置到数据库
func (m *model) saveConfig() error {
conn, err := database.GetDB() // 假设database包已实现getDB()获取连接
if err != nil {
return fmt.Errorf("获取数据库连接失败: %w", err)
}
defer conn.Close()
m.force = false // 初始化全量覆盖标识
// 遍历所有配置项,逐个处理
for _, item := range configMappings {
val := item.getVal(m)
exists, err := m.checkExists(conn, item.table, item.key)
if err != nil {
return fmt.Errorf("检查%s.%s是否存在失败: %w", item.table, item.key, err)
}
// 根据存在性和用户选择处理
if !exists {
// 不存在则直接插入
if err := m.upsertConfig(conn, item.table, item.key, val, false); err != nil {
return fmt.Errorf("插入%s.%s失败: %w", item.table, item.key, err)
}
continue
}
// 已存在:判断是否全量覆盖
if m.force {
if err := m.upsertConfig(conn, item.table, item.key, val, true); err != nil {
return fmt.Errorf("强制更新%s.%s失败: %w", item.table, item.key, err)
}
continue
}
// 询问用户操作
choice, err := m.askUserChoice(item.table, item.key)
if err != nil {
return fmt.Errorf("获取用户选择失败: %w", err)
}
switch strings.ToLower(choice) {
case "y", "yes":
// 单条覆盖
if err := m.upsertConfig(conn, item.table, item.key, val, true); err != nil {
return fmt.Errorf("更新%s.%s失败: %w", item.table, item.key, err)
}
case "a", "all":
// 全量覆盖,后续不再询问
m.force = true
if err := m.upsertConfig(conn, item.table, item.key, val, true); err != nil {
return fmt.Errorf("全量更新%s.%s失败: %w", item.table, item.key, err)
}
case "n", "no":
// 跳过当前项
fmt.Printf("跳过%s.%s的更新\n", item.table, item.key)
default:
fmt.Printf("无效选择%s跳过%s.%s的更新\n", choice, item.table, item.key)
}
}
return nil
}
// checkExists 集中判断配置项是否存在(核心判断逻辑)
func (m *model) checkExists(conn *sql.DB, table, key string) (bool, error) {
var count int
// 通用存在性检查SQL假设所有表都有key字段作为主键
query := fmt.Sprintf("SELECT COUNT(1) FROM %s WHERE `key` = ?", table)
err := conn.QueryRow(query, key).Scan(&count)
if err != nil {
// 表不存在也视为"不存在"(可选:根据实际需求调整,比如先建表)
if strings.Contains(err.Error(), "table not found") {
return false, nil
}
return false, err
}
return count > 0, nil
}
// upsertConfig 统一处理插入/更新逻辑
func (m *model) upsertConfig(conn *sql.DB, table, key string, val interface{}, update bool) error {
var query string
if !update {
// 插入:假设表结构为(key, value)
query = fmt.Sprintf("INSERT INTO %s (`key`, `value`) VALUES (?, ?)", table)
} else {
// 更新
query = fmt.Sprintf("UPDATE %s SET `value` = ? WHERE `key` = ?", table)
}
// 处理参数顺序(更新和插入的参数顺序不同)
var args []interface{}
if !update {
args = []interface{}{key, val}
} else {
args = []interface{}{val, key}
}
_, err := conn.Exec(query, args...)
return err
}
// askUserChoice 询问用户操作选择
func (m *model) askUserChoice(table, key string) (string, error) {
reader := bufio.NewReader(os.Stdin)
fmt.Printf("配置项%s.%s已存在选择操作(y/yes=覆盖, n/no=跳过, a/all=全量覆盖后续所有): ", table, key)
input, err := reader.ReadString('\n')
if err != nil {
return "", err
}
// 去除空格和换行
return strings.TrimSpace(input), nil
}
// 获取系统网络接口
func getNetworkInterfaces() []string {
// 实现获取系统网络接口的逻辑
// 例如:使用 net.Interface() 函数获取系统网络接口
// 返回一个字符串切片,包含系统网络接口的名称
interfaces, err := net.Interfaces()
if err != nil {
return []string{utils.NoAvailableNetworkInterfaces}
}
var result []string
for _, iface := range interfaces {
// 跳过 loopback 接口
if iface.Flags&net.FlagLoopback != 0 {
continue
}
result = append(result, iface.Name)
}
if len(result) == 0 {
return []string{utils.NoAvailableNetworkInterfaces}
}
return result
}