package wizard import ( "bufio" "database/sql" "fmt" "net" "os" "strings" "sunhpc/pkg/database" "sunhpc/pkg/utils" ) // 配置项映射:定义每个配置项对应的表名、键名 var configMappings = []struct { table string key string getVal func(m *model) interface{} // 动态获取配置值的函数 }{ // attributes 表 {"attributes", "license", func(m *model) any { return m.config.License }}, {"attributes", "accepted", func(m *model) any { return m.config.AgreementAccepted }}, {"attributes", "country", func(m *model) any { return m.config.Country }}, {"attributes", "region", func(m *model) any { return m.config.Region }}, {"attributes", "timezone", func(m *model) any { return m.config.Timezone }}, {"attributes", "homepage", func(m *model) any { return m.config.HomePage }}, {"attributes", "dbaddress", func(m *model) any { return m.config.DBAddress }}, {"attributes", "software", func(m *model) any { return m.config.Software }}, // nodes 表 {"nodes", "name", func(m *model) any { return m.config.Hostname }}, // 公网设置表 {"public_network", "public_interface", func(m *model) any { return m.config.PublicInterface }}, {"public_network", "ip_address", func(m *model) any { return m.config.PublicIPAddress }}, {"public_network", "netmask", func(m *model) any { return m.config.PublicNetmask }}, {"public_network", "gateway", func(m *model) any { return m.config.PublicGateway }}, // 内网配置表 {"internal_network", "internal_interface", func(m *model) any { return m.config.InternalInterface }}, {"internal_network", "internal_ip", func(m *model) any { return m.config.InternalIPAddress }}, {"internal_network", "internal_mask", func(m *model) any { return m.config.InternalNetmask }}, // DNS配置表 {"dns_config", "dns_primary", func(m *model) any { return m.config.DNSPrimary }}, {"dns_config", "dns_secondary", func(m *model) any { return m.config.DNSSecondary }}, } // saveConfig 入口函数:保存所有配置到数据库 func (m *model) saveConfig() error { conn, err := database.GetDB() // 假设database包已实现getDB()获取连接 if err != nil { return fmt.Errorf("获取数据库连接失败: %w", err) } defer conn.Close() m.force = false // 初始化全量覆盖标识 // 遍历所有配置项,逐个处理 for _, item := range configMappings { val := item.getVal(m) exists, err := m.checkExists(conn, item.table, item.key) if err != nil { return fmt.Errorf("检查%s.%s是否存在失败: %w", item.table, item.key, err) } // 根据存在性和用户选择处理 if !exists { // 不存在则直接插入 if err := m.upsertConfig(conn, item.table, item.key, val, false); err != nil { return fmt.Errorf("插入%s.%s失败: %w", item.table, item.key, err) } continue } // 已存在:判断是否全量覆盖 if m.force { if err := m.upsertConfig(conn, item.table, item.key, val, true); err != nil { return fmt.Errorf("强制更新%s.%s失败: %w", item.table, item.key, err) } continue } // 询问用户操作 choice, err := m.askUserChoice(item.table, item.key) if err != nil { return fmt.Errorf("获取用户选择失败: %w", err) } switch strings.ToLower(choice) { case "y", "yes": // 单条覆盖 if err := m.upsertConfig(conn, item.table, item.key, val, true); err != nil { return fmt.Errorf("更新%s.%s失败: %w", item.table, item.key, err) } case "a", "all": // 全量覆盖,后续不再询问 m.force = true if err := m.upsertConfig(conn, item.table, item.key, val, true); err != nil { return fmt.Errorf("全量更新%s.%s失败: %w", item.table, item.key, err) } case "n", "no": // 跳过当前项 fmt.Printf("跳过%s.%s的更新\n", item.table, item.key) default: fmt.Printf("无效选择%s,跳过%s.%s的更新\n", choice, item.table, item.key) } } return nil } // checkExists 集中判断配置项是否存在(核心判断逻辑) func (m *model) checkExists(conn *sql.DB, table, key string) (bool, error) { var count int // 通用存在性检查SQL(假设所有表都有key字段作为主键) query := fmt.Sprintf("SELECT COUNT(1) FROM %s WHERE `key` = ?", table) err := conn.QueryRow(query, key).Scan(&count) if err != nil { // 表不存在也视为"不存在"(可选:根据实际需求调整,比如先建表) if strings.Contains(err.Error(), "table not found") { return false, nil } return false, err } return count > 0, nil } // upsertConfig 统一处理插入/更新逻辑 func (m *model) upsertConfig(conn *sql.DB, table, key string, val interface{}, update bool) error { var query string if !update { // 插入:假设表结构为(key, value) query = fmt.Sprintf("INSERT INTO %s (`key`, `value`) VALUES (?, ?)", table) } else { // 更新 query = fmt.Sprintf("UPDATE %s SET `value` = ? WHERE `key` = ?", table) } // 处理参数顺序(更新和插入的参数顺序不同) var args []interface{} if !update { args = []interface{}{key, val} } else { args = []interface{}{val, key} } _, err := conn.Exec(query, args...) return err } // askUserChoice 询问用户操作选择 func (m *model) askUserChoice(table, key string) (string, error) { reader := bufio.NewReader(os.Stdin) fmt.Printf("配置项%s.%s已存在,选择操作(y/yes=覆盖, n/no=跳过, a/all=全量覆盖后续所有): ", table, key) input, err := reader.ReadString('\n') if err != nil { return "", err } // 去除空格和换行 return strings.TrimSpace(input), nil } // 获取系统网络接口 func getNetworkInterfaces() []string { // 实现获取系统网络接口的逻辑 // 例如:使用 net.Interface() 函数获取系统网络接口 // 返回一个字符串切片,包含系统网络接口的名称 interfaces, err := net.Interfaces() if err != nil { return []string{utils.NoAvailableNetworkInterfaces} } var result []string for _, iface := range interfaces { // 跳过 loopback 接口 if iface.Flags&net.FlagLoopback != 0 { continue } result = append(result, iface.Name) } if len(result) == 0 { return []string{utils.NoAvailableNetworkInterfaces} } return result }