commit d7cd899983afe892ff8ae26513f27d40faf89d2d Author: kelvin Date: Sat Feb 14 05:36:00 2026 +0800 ok-1 diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..6c3beeb --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,4 @@ +{ + "editor.tabSize": 4, + "editor.detectIndentation": true +} \ No newline at end of file diff --git a/cmd/init/config.go b/cmd/init/config.go new file mode 100644 index 0000000..1e4a24a --- /dev/null +++ b/cmd/init/config.go @@ -0,0 +1,41 @@ +package initcmd + +import ( + "fmt" + "os" + + "sunhpc/internal/auth" + "sunhpc/internal/config" + "sunhpc/internal/log" + + "github.com/spf13/cobra" +) + +var configCmd = &cobra.Command{ + Use: "config", + Short: "生成基础配置文件", + Long: "创建 /etc/sunhpc 目录并生成所有默认配置文件(若目录已存在则跳过)", + RunE: func(cmd *cobra.Command, args []string) error { + if err := auth.RequireRoot(); err != nil { + return err + } + + // 检查目录是否已存在 + if _, err := os.Stat(config.BaseDir); err == nil { + log.Warnf("配置目录 %s 已存在,跳过初始化", config.BaseDir) + return nil + } + + log.Info("初始化 SunHPC 配置目录...") + if err := config.InitDirs(); err != nil { + return fmt.Errorf("创建目录失败: %v", err) + } + + if err := config.CreateDefaultConfigs(); err != nil { + return fmt.Errorf("生成默认配置文件失败: %v", err) + } + + log.Info("配置文件已生成,请根据需要编辑 /etc/sunhpc/ 下的 YAML 文件") + return nil + }, +} diff --git a/cmd/init/database.go b/cmd/init/database.go new file mode 100644 index 0000000..942c498 --- /dev/null +++ b/cmd/init/database.go @@ -0,0 +1,193 @@ +package initcmd + +import ( + "bufio" + "fmt" + "os" + "path/filepath" + "strings" + + "sunhpc/internal/auth" + "sunhpc/internal/db" + "sunhpc/internal/log" + + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +const ( + appName string = "sunhpc" + defaultDBPath string = "/var/lib/sunhpc" + defaultDBName string = "sunhpc.db" +) + +var ( + forceDB bool + dbPath string + dbName string + dbFullPath string +) + +func initDBPathWithViper() error { + /* + 从 Viper 配置文件获取数据库路径 + 配置文件里的键要和 Viper.GetXXX 的键对应 + 配置文件格式: + db: + path: "/tmp/sunhpc" # 自定义数据库路径 + name: "my_sunhpc.db" # 自定义数据库名 + */ + + log.Infof("从 Viper 配置文件获取数据库路径...") + + // ========== 第一步:设置 Viper 配置文件规则(核心) ========== + // 1. 设置Viper基础规则 + viper.SetConfigType("yaml") // 配置文件类型 + viper.SetConfigName("config") // 配置文件名(不带后缀) + viper.SetEnvPrefix(appName) // 环境变量前缀:SUNHPC_ + viper.AutomaticEnv() // 自动读取环境变量(可选,增强兼容性) + + // 2. 添加配置文件搜索目录(Viper 会按顺序查找,找到第一个就停止) + // 优先级:当前目录 → 用户级目录 → 系统级目录 + + // ① 当前目录(开发/测试常用) + viper.AddConfigPath(".") + + // ② Linux/macOS 用户级目录(~/.config/sunhpc/) + if homeDir, err := os.UserHomeDir(); err == nil { + viper.AddConfigPath(filepath.Join(homeDir, ".config", appName)) + } + // ③ Linux/macOS 系统级目录(/etc/sunhpc/) + viper.AddConfigPath(filepath.Join("/etc", appName)) + + // ========== 第二步:设置默认值(最低优先级) ========== + viper.SetDefault("db.path", defaultDBPath) + viper.SetDefault("db.name", defaultDBName) + + // ========== 第三步:绑定环境变量(优先级高于默认值,低于配置文件) ========== + viper.BindEnv("db.path", "DB_PATH") // 绑定 SUNHPC_DB_PATH → db.path + viper.BindEnv("db.name", "DB_NAME") // 绑定 SUNHPC_DB_NAME → db.name + + // ========== 第四步:读取配置文件(优先级高于环境变量,低于默认值) ========== + if err := viper.ReadInConfig(); err != nil { + if _, ok := err.(viper.ConfigFileNotFoundError); ok { + log.Info("未找到配置文件,将使用环境变量/默认值") + return nil // 配置文件存在但格式错误,返回错误 + } + log.Warnf("读取配置文件失败: %v", err) + return fmt.Errorf("读取配置文件失败: %w", err) + } + + log.Infof("成功加载配置文件: %s", viper.ConfigFileUsed()) + return nil +} + +func initDBPath() error { + + // 1. 从 Viper 配置文件获取数据库路径(加载配置文件->环境变量->默认值) + if err := initDBPathWithViper(); err != nil { + return fmt.Errorf("Viper初始化数据库失败: %v", err) + } + + // 2. 从Viper获取数据库路径 + dbPath = viper.GetString("db.path") + dbName = viper.GetString("db.name") + + // 3. 拼接数据库路径 + dbFullPath = filepath.Join(dbPath, dbName) + log.Infof("数据库完整路径: %s", dbFullPath) + + // 3. 检查数据库文件是否存在 + dir := filepath.Dir(dbFullPath) + // 4. 检查数据库目录是否存在 + if _, err := os.Stat(dir); os.IsNotExist(err) { + log.Infof("数据库目录不存在,创建目录: %s", dir) + + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("创建数据库目录失败: %v", err) + } + } + return nil +} + +var databaseCmd = &cobra.Command{ + Use: "database", + Short: "初始化数据库", + Long: `初始化SQLite数据库,创建所有表结构和默认数据。 + +示例: + sunhpc init database # 初始化数据库 + sunhpc init database --force # 强制重新初始化`, + + RunE: func(cmd *cobra.Command, args []string) error { + if err := auth.RequireRoot(); err != nil { + return err + } + + log.Debug("初始化数据库...") + + if err := initDBPath(); err != nil { + return fmt.Errorf("初始化数据库路径失败: %w", err) + } + + log.Debugf("数据库目录存在: %s", dbPath) + + // 强制模式:用户确认 + if forceDB { + log.Warn("⚠️ 警告:强制重新初始化将清空数据库中的所有数据!") + fmt.Printf("数据库路径: %s\n", dbFullPath) + fmt.Print("确认要重新初始化数据库吗?这将删除所有现有数据。(Y/yes): ") + + reader := bufio.NewReader(os.Stdin) + input, err := reader.ReadString('\n') + if err != nil { + return fmt.Errorf("读取用户输入失败: %v", err) + } + + input = strings.TrimSpace(strings.ToLower(input)) + if input != "y" && input != "yes" { + log.Info("操作已取消") + return nil + } + + log.Info("用户确认重新初始化数据库") + } + + // 数据库存在且不是强制模式则跳过初始化 + if _, err := os.Stat(dbFullPath); err == nil && !forceDB { + log.Infof("数据库文件已存在: %s", dbFullPath) + return nil + } + + // 初始化数据库(使用配置的路径) + database, err := db.GetInstanceWithConfig(dbPath, dbName) + if err != nil { + return fmt.Errorf("初始化数据库失败: %v", err) + } + defer database.Close() + + // 如果是强制模式,设置强制重新初始化标志 + if forceDB { + database.SetForceInit(true) + log.Info("强制重新初始化数据库表...") + + // 关闭现有连接以触发重新连接 + if err := database.CloseConnection(); err != nil { + return fmt.Errorf("关闭现有数据库连接失败: %v", err) + } + + // 重新连接并初始化 + if err := database.Connect(); err != nil { + return fmt.Errorf("强制重新初始化数据库失败: %v", err) + } + } + + log.Infof("数据库初始化成功: %s", dbFullPath) + return nil + }, +} + +func init() { + databaseCmd.Flags().BoolVarP(&forceDB, "force", "f", false, "强制重新初始化,删除现有数据库") + Cmd.AddCommand(databaseCmd) +} diff --git a/cmd/init/init.go b/cmd/init/init.go new file mode 100644 index 0000000..6506f0d --- /dev/null +++ b/cmd/init/init.go @@ -0,0 +1,17 @@ +package initcmd + +import ( + "github.com/spf13/cobra" +) + +var Cmd = &cobra.Command{ + Use: "init", + Short: "初始化集群配置", + Long: "初始化 SunHPC 配置文件、数据库、系统参数及相关服务", +} + +func init() { + Cmd.AddCommand(configCmd) + Cmd.AddCommand(systemCmd) + Cmd.AddCommand(serviceCmd) +} diff --git a/cmd/init/service.go b/cmd/init/service.go new file mode 100644 index 0000000..44bbe96 --- /dev/null +++ b/cmd/init/service.go @@ -0,0 +1,37 @@ +package initcmd + +import ( + "fmt" + + "sunhpc/internal/auth" + "sunhpc/internal/config" + "sunhpc/internal/log" + "sunhpc/internal/service" + + "github.com/spf13/cobra" +) + +var serviceCmd = &cobra.Command{ + Use: "service", + Short: "根据配置文件初始化服务", + Long: `读取 /etc/sunhpc/services.yaml 并部署/配置相关服务。 +支持 HTTPD、TFTPD、DHCPD 等。`, + RunE: func(cmd *cobra.Command, args []string) error { + if err := auth.RequireRoot(); err != nil { + return err + } + + svcCfg, err := config.LoadServices() + if err != nil { + return fmt.Errorf("加载 services.yaml 失败: %v", err) + } + + log.Info("开始部署服务...") + if err := service.Deploy(svcCfg); err != nil { + return fmt.Errorf("服务部署失败: %v", err) + } + + log.Info("服务初始化完成") + return nil + }, +} diff --git a/cmd/init/system.go b/cmd/init/system.go new file mode 100644 index 0000000..79ed85d --- /dev/null +++ b/cmd/init/system.go @@ -0,0 +1,49 @@ +package initcmd + +import ( + "fmt" + + "sunhpc/internal/auth" + "sunhpc/internal/config" + "sunhpc/internal/system" + + "github.com/spf13/cobra" +) + +var ( + dryRun bool // --dry-run -n: 仅模拟执行,不实际应用 + verbose bool // --verbose -v: 启用详细日志输出 +) + +var systemCmd = &cobra.Command{ + Use: "system [flags]", + Short: "根据配置文件初始化系统", + Long: `读取 /etc/sunhpc/sunhpc.yaml 中的系统配置项并应用到当前节点。 + 示例: + sunhpc init system # 应用所有配置项 + sunhpc init system --dry-run # 仅模拟执行,不实际应用 + sunhpc init system --verbose # 启用详细日志输出 + `, + RunE: func(cmd *cobra.Command, args []string) error { + // 权限检查:必须以 root 或 sudo 运行 + if err := auth.RequireRoot(); err != nil { + return err + } + + // 加载主配置 + cfg, err := config.LoadSunHPC() + if err != nil { + return fmt.Errorf("加载 sunhpc.yaml 失败: %v", err) + } + + // 统一应用所有配置 + return system.ApplyAll(cfg) + }, +} + +// init 初始化 systemCmd 的标志,添加长参数和段参数. +func init() { + // 注册长参数, 布尔参数 + systemCmd.Flags().BoolVarP(&dryRun, "dry-run", "n", false, "仅模拟执行,不实际应用") + systemCmd.Flags().BoolVarP(&verbose, "verbose", "v", false, "启用详细日志输出") +} diff --git a/cmd/node.go b/cmd/node.go new file mode 100644 index 0000000..8118fbe --- /dev/null +++ b/cmd/node.go @@ -0,0 +1,113 @@ +package cmd + +import ( + "fmt" + + "sunhpc/internal/db" + "sunhpc/internal/log" + + "github.com/spf13/cobra" +) + +var nodeCmd = &cobra.Command{ + Use: "node", + Short: "节点管理", + Long: "管理集群节点,包括添加、删除、查询等操作", +} + +var nodeListCmd = &cobra.Command{ + Use: "list", + Short: "列出所有节点", + RunE: func(cmd *cobra.Command, args []string) error { + log.Info("查询节点列表...") + + // 获取数据库实例(自动使用之前配置的路径) + database, err := db.GetInstance() + if err != nil { + return fmt.Errorf("获取数据库连接失败: %v", err) + } + defer database.Close() + + // 执行查询 + _, err = database.Execute("SELECT id, name, rack, rank, cpus, memory, disk, os, kernel FROM nodes ORDER BY name") + if err != nil { + return fmt.Errorf("查询节点失败: %v", err) + } + + // 获取所有结果 + rows, err := database.FetchAll() + if err != nil { + return fmt.Errorf("获取结果失败: %v", err) + } + + if len(rows) == 0 { + log.Info("暂无节点数据") + return nil + } + + // 打印结果 + fmt.Printf("%-5s %-20s %-8s %-8s %-8s %-10s %-10s %-10s\n", + "ID", "名称", "机架", "排名", "CPU", "内存", "磁盘", "操作系统") + fmt.Println("----------------------------------------------------------------------------------") + for _, row := range rows { + fmt.Printf("%-5v %-20s %-8v %-8v %-8v %-10v %-10v %-10s\n", + row["id"], row["name"], row["rack"], row["rank"], + row["cpus"], row["memory"], row["disk"], row["os"]) + } + + return nil + }, +} + +var nodeAddCmd = &cobra.Command{ + Use: "add ", + Short: "添加节点", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + nodeName := args[0] + log.Infof("添加节点: %s", nodeName) + + database, err := db.GetInstance() + if err != nil { + return fmt.Errorf("获取数据库连接失败: %v", err) + } + defer database.Close() + + // 插入节点 + _, err = database.Execute( + "INSERT INTO nodes (name, rack, rank, cpus, memory, disk, os, kernel) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + nodeName, rack, rank, cpus, memory, disk, nodeOS, kernel, + ) + if err != nil { + return fmt.Errorf("添加节点失败: %v", err) + } + + log.Infof("节点 %s 添加成功", nodeName) + return nil + }, +} + +var ( + rack int + rank int + cpus int + memory int + disk int + nodeOS string + kernel string +) + +func init() { + // 添加子命令 + nodeCmd.AddCommand(nodeListCmd) + nodeCmd.AddCommand(nodeAddCmd) + + // 添加参数 + nodeAddCmd.Flags().IntVar(&rack, "rack", 0, "机架号") + nodeAddCmd.Flags().IntVar(&rank, "rank", 0, "排名") + nodeAddCmd.Flags().IntVar(&cpus, "cpus", 0, "CPU核心数") + nodeAddCmd.Flags().IntVar(&memory, "memory", 0, "内存大小(GB)") + nodeAddCmd.Flags().IntVar(&disk, "disk", 0, "磁盘大小(GB)") + nodeAddCmd.Flags().StringVar(&nodeOS, "os", "", "操作系统") + nodeAddCmd.Flags().StringVar(&kernel, "kernel", "", "内核版本") +} diff --git a/cmd/root.go b/cmd/root.go new file mode 100644 index 0000000..c746cc3 --- /dev/null +++ b/cmd/root.go @@ -0,0 +1,73 @@ +package cmd + +import ( + initcmd "sunhpc/cmd/init" + "sunhpc/cmd/soft" + "sunhpc/internal/log" + + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +var ( + cfgFile string + verbose bool + noColor bool +) + +var rootCmd = &cobra.Command{ + Use: "sunhpc", + Short: "SunHPC - HPC集群一体化运维工具", + PersistentPreRun: func(cmd *cobra.Command, args []string) { + // 初始化日志 + log.Init(verbose) + + // 是否禁用颜色 + if noColor { + log.EnableColor(false) + } + + log.Debugf("命令: %s", cmd.Name()) + log.Debugf("详细模式: %v", verbose) + }, + PersistentPostRun: func(cmd *cobra.Command, args []string) { + // 同步日志 + log.Sync() + log.Close() + }, +} + +func Execute() error { + return rootCmd.Execute() +} + +func init() { + cobra.OnInitialize(initConfig) + + rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "配置文件路径 (默认为 /etc/sunhpc/sunhpc.yaml)") + rootCmd.PersistentFlags().BoolVarP(&verbose, "verbose", "v", false, "启用详细日志输出") + rootCmd.PersistentFlags().BoolVar(&noColor, "no-color", false, "禁用彩色输出") + + // 注册一级子命令 + rootCmd.AddCommand(initcmd.Cmd) + rootCmd.AddCommand(soft.Cmd) + rootCmd.AddCommand(nodeCmd) +} + +func initConfig() { + if cfgFile != "" { + viper.SetConfigFile(cfgFile) + } else { + viper.AddConfigPath("/etc/sunhpc") + viper.SetConfigType("yaml") + viper.SetConfigName("sunhpc") + } + + viper.AutomaticEnv() + + if err := viper.ReadInConfig(); err == nil { + log.Infof("使用配置文件: %s", viper.ConfigFileUsed()) + } else { + log.Debugf("未找到配置文件: %v", err) + } +} diff --git a/cmd/soft/install.go b/cmd/soft/install.go new file mode 100644 index 0000000..fa378cd --- /dev/null +++ b/cmd/soft/install.go @@ -0,0 +1,109 @@ +// cmd/soft/install.go +package soft + +import ( + "fmt" + "sunhpc/internal/auth" + "sunhpc/internal/log" + "sunhpc/internal/soft" + + "github.com/spf13/cobra" +) + +var ( + installType string // --type, -t + srcPath string // --src-path, -s + binPath string // --bin-path, -b + prefix string // --prefix, -p + version string // --version, -v + forceInstall bool // --force, -f + dryRun bool // --dry-run, -n + keepSource bool // --keep-source, -k + jobs int // --jobs, -j + offlineMode bool // --offline, -o +) + +var installCmd = &cobra.Command{ + Use: "install ", + Short: "安装软件", + Long: `安装指定的软件包,支持多种安装方式。 + +安装类型: + source - 从源码编译安装 + binary - 从二进制压缩包安装 + rpm - 通过 RPM 包管理器安装 + deb - 通过 APT 包管理器安装 + +示例: + sunhpc soft install vasp --type source --src-path /tmp/vasp.tar.gz + sunhpc soft install openmpi --type binary --bin-path openmpi.tar.gz -p /opt/openmpi + sunhpc soft install htop --type rpm --force + sunhpc soft install nginx --type deb --dry-run`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + if err := auth.RequireRoot(); err != nil { + return err + } + + software := args[0] + + if dryRun { + log.Infof("[干运行] 将要安装 %s", software) + log.Infof(" 安装类型: %s", installType) + if srcPath != "" { + log.Infof(" 源码路径: %s", srcPath) + } + if binPath != "" { + log.Infof(" 二进制包: %s", binPath) + } + if prefix != "" { + log.Infof(" 安装路径: %s", prefix) + } + return nil + } + + ctx := &soft.InstallContext{ + Force: forceInstall, + DryRun: dryRun, + KeepSource: keepSource, + Jobs: jobs, + Offline: offlineMode, + } + + switch installType { + case "source": + return soft.InstallFromSource(software, srcPath, prefix, version, ctx) + case "binary": + return soft.InstallFromBinary(software, binPath, prefix, ctx) + case "rpm", "deb": + return soft.InstallFromPackage(software, installType, ctx) + default: + return fmt.Errorf("不支持的安装类型: %s", installType) + } + }, +} + +func init() { + // 必选参数 + installCmd.Flags().StringVarP(&installType, "type", "t", "", "安装类型: source/binary/rpm/deb") + installCmd.MarkFlagRequired("type") + + // 路径参数 + installCmd.Flags().StringVarP(&srcPath, "src-path", "s", "", "源码路径或URL") + installCmd.Flags().StringVarP(&binPath, "bin-path", "b", "", "二进制压缩包路径") + installCmd.Flags().StringVarP(&prefix, "prefix", "p", "/opt/sunhpc/software", "安装路径") + + // 版本参数 + installCmd.Flags().StringVarP(&version, "version", "v", "", "软件版本号") + + // 行为控制 + installCmd.Flags().BoolVarP(&forceInstall, "force", "f", false, "强制安装,覆盖已有版本") + installCmd.Flags().BoolVarP(&dryRun, "dry-run", "n", false, "仅显示将要执行的操作") + installCmd.Flags().BoolVarP(&keepSource, "keep-source", "k", false, "保留源码文件") + installCmd.Flags().IntVarP(&jobs, "jobs", "j", 4, "编译线程数") + installCmd.Flags().BoolVarP(&offlineMode, "offline", "o", false, "离线模式,不联网下载") + + // 参数互斥 + installCmd.MarkFlagsMutuallyExclusive("src-path", "bin-path") + installCmd.MarkFlagsOneRequired("src-path", "bin-path") +} diff --git a/cmd/soft/soft.go b/cmd/soft/soft.go new file mode 100644 index 0000000..fb16736 --- /dev/null +++ b/cmd/soft/soft.go @@ -0,0 +1,16 @@ +package soft + +import ( + "github.com/spf13/cobra" +) + +var Cmd = &cobra.Command{ + Use: "soft", + Short: "软件包管理", + Long: "安装、卸载、编译软件,支持源码、RPM、DEB、二进制压缩包", +} + +func init() { + Cmd.AddCommand(installCmd) + // 后续可添加 remove、list 等子命令 +} diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000..384d46b --- /dev/null +++ b/config.yaml @@ -0,0 +1,3 @@ +db: + path: "/tmp/sunhpc" + name: "sunhpc.db" diff --git a/docs/database_usage.md b/docs/database_usage.md new file mode 100644 index 0000000..4175e02 --- /dev/null +++ b/docs/database_usage.md @@ -0,0 +1,344 @@ +# 数据库使用指南 + +## 📋 概述 + +SunHPC 使用 SQLite 数据库,支持自定义数据库路径和名称。数据库配置通过 `sunhpc init database` 命令初始化,之后所有命令都可以通过单例模式访问同一个数据库实例。 + +## 🚀 快速开始 + +### 1. 初始化数据库 + +```bash +# 使用默认路径 (/var/lib/sunhpc/sunhpc.db) +sunhpc init database + +# 使用自定义配置文件 +sunhpc init database --config /path/to/config.yaml + +# 使用环境变量 +DB_PATH=/opt/sunhpc/data DB_NAME=cluster.db sunhpc init database + +# 强制重新初始化 +sunhpc init database --force +``` + +### 2. 在其他命令中使用数据库 + +```go +package cmd + +import ( + "fmt" + "sunhpc/internal/db" + "sunhpc/internal/log" + "github.com/spf13/cobra" +) + +var myCmd = &cobra.Command{ + Use: "mycommand", + Short: "我的命令", + RunE: func(cmd *cobra.Command, args []string) error { + // 获取数据库实例(自动使用配置的路径) + database, err := db.GetInstance() + if err != nil { + return fmt.Errorf("获取数据库连接失败: %v", err) + } + defer database.Close() + + // 执行查询 + _, err = database.Execute("SELECT * FROM nodes") + if err != nil { + return fmt.Errorf("查询失败: %v", err) + } + + // 获取结果 + rows, err := database.FetchAll() + if err != nil { + return fmt.Errorf("获取结果失败: %v", err) + } + + // 处理结果 + for _, row := range rows { + log.Infof("节点: %v", row) + } + + return nil + }, +} +``` + +## 🔧 配置说明 + +### 配置文件格式 + +创建 `config.yaml` 文件: + +```yaml +db: + path: "/opt/sunhpc/data" # 数据库目录路径 + name: "my_cluster.db" # 数据库文件名 +``` + +### 配置优先级 + +从高到低: + +1. **配置文件**:`config.yaml` 中的 `db.path` 和 `db.name` +2. **环境变量**:`DB_PATH` 和 `DB_NAME` +3. **默认值**:`/var/lib/sunhpc` 和 `sunhpc.db` + +### 环境变量 + +```bash +# 设置数据库路径 +export DB_PATH=/tmp/sunhpc +export DB_NAME=test.db + +# 使用环境变量 +sunhpc init database +``` + +## 📊 数据库 API + +### 获取数据库实例 + +```go +// 方式1:使用默认配置(推荐) +database, err := db.GetInstance() + +// 方式2:指定路径和名称(仅在初始化时使用) +database, err := db.GetInstanceWithConfig("/path/to/db", "mydb.db") + +// 检查实例是否已配置 +if db.IsInstanceConfigured() { + dbPath, dbName := db.GetInstanceConfig() + log.Infof("数据库: %s/%s", dbPath, dbName) +} +``` + +### 执行 SQL + +```go +// 执行查询 +_, err := database.Execute("SELECT * FROM nodes WHERE name = ?", "node1") + +// 执行插入 +_, err := database.Execute( + "INSERT INTO nodes (name, cpus, memory) VALUES (?, ?, ?)", + "node2", 32, 128, +) + +// 执行更新 +_, err := database.Execute( + "UPDATE nodes SET cpus = ? WHERE name = ?", + 64, "node1", +) + +// 执行删除 +_, err := database.Execute("DELETE FROM nodes WHERE name = ?", "node1") +``` + +### 获取查询结果 + +```go +// 获取单行 +row, err := database.FetchOne() +if err != nil { + return err +} +if row != nil { + log.Infof("节点名称: %s", row["name"]) + log.Infof("CPU数量: %v", row["cpus"]) +} + +// 获取所有行 +rows, err := database.FetchAll() +if err != nil { + return err +} +for _, row := range rows { + log.Infof("节点: %v", row) +} +``` + +## 🎯 使用示例 + +### 示例1:查询节点列表 + +```go +func listNodes() error { + database, err := db.GetInstance() + if err != nil { + return err + } + defer database.Close() + + _, err = database.Execute("SELECT id, name, cpus, memory FROM nodes ORDER BY name") + if err != nil { + return err + } + + rows, err := database.FetchAll() + if err != nil { + return err + } + + for _, row := range rows { + fmt.Printf("%-5s %-20s %-8s %-10s\n", + row["id"], row["name"], row["cpus"], row["memory"]) + } + + return nil +} +``` + +### 示例2:添加节点 + +```go +func addNode(name string, cpus, memory int) error { + database, err := db.GetInstance() + if err != nil { + return err + } + defer database.Close() + + _, err = database.Execute( + "INSERT INTO nodes (name, cpus, memory) VALUES (?, ?, ?)", + name, cpus, memory, + ) + if err != nil { + return fmt.Errorf("添加节点失败: %v", err) + } + + return nil +} +``` + +### 示例3:更新节点 + +```go +func updateNode(name string, cpus, memory int) error { + database, err := db.GetInstance() + if err != nil { + return err + } + defer database.Close() + + _, err = database.Execute( + "UPDATE nodes SET cpus = ?, memory = ? WHERE name = ?", + cpus, memory, name, + ) + if err != nil { + return fmt.Errorf("更新节点失败: %v", err) + } + + return nil +} +``` + +### 示例4:删除节点 + +```go +func deleteNode(name string) error { + database, err := db.GetInstance() + if err != nil { + return err + } + defer database.Close() + + _, err = database.Execute("DELETE FROM nodes WHERE name = ?", name) + if err != nil { + return fmt.Errorf("删除节点失败: %v", err) + } + + return nil +} +``` + +## 🔍 数据库表结构 + +### nodes 表 + +```sql +CREATE TABLE nodes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + rack INTEGER DEFAULT 0, + rank INTEGER DEFAULT 0, + membership_id INTEGER, + cpus INTEGER DEFAULT 0, + memory INTEGER DEFAULT 0, + disk INTEGER DEFAULT 0, + os TEXT, + kernel TEXT, + last_state_change DATETIME DEFAULT CURRENT_TIMESTAMP, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP +); +``` + +### networks 表 + +```sql +CREATE TABLE networks ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + node_id INTEGER NOT NULL, + name TEXT, + ip TEXT UNIQUE, + mac TEXT UNIQUE, + subnet_id INTEGER, + interface TEXT DEFAULT 'eth0', + FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE CASCADE, + FOREIGN KEY (subnet_id) REFERENCES subnets(id) ON DELETE SET NULL +); +``` + +### software_installs 表 + +```sql +CREATE TABLE software_installs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + version TEXT, + install_type TEXT, + node_id INTEGER, + status TEXT, + installed_at DATETIME DEFAULT CURRENT_TIMESTAMP, + installed_by TEXT, + FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE SET NULL +); +``` + +## ⚠️ 注意事项 + +1. **单例模式**:`GetInstance()` 使用单例模式,整个程序只创建一个数据库实例 +2. **路径一致性**:所有命令都使用同一个数据库路径,确保数据一致性 +3. **关闭连接**:使用完毕后调用 `database.Close()` 释放资源 +4. **错误处理**:始终检查错误返回值 +5. **SQL注入防护**:使用参数化查询(`?` 占位符) + +## 📝 最佳实践 + +1. **初始化优先**:在程序启动时先执行 `sunhpc init database` +2. **配置管理**:使用配置文件统一管理数据库路径 +3. **事务处理**:复杂操作使用事务确保数据一致性 +4. **日志记录**:记录所有数据库操作,便于调试 +5. **资源释放**:使用 `defer database.Close()` 确保连接关闭 + +## 🆘 常见问题 + +### Q: 如何切换数据库路径? + +A: 重新运行 `sunhpc init database` 命令,指定新的配置文件或环境变量。 + +### Q: 多个命令会创建多个数据库实例吗? + +A: 不会。`GetInstance()` 使用单例模式,整个程序只创建一个实例。 + +### Q: 如何查看当前使用的数据库路径? + +A: 使用 `db.GetInstanceConfig()` 获取配置信息。 + +### Q: 数据库文件不存在会怎样? + +A: 首次调用 `GetInstance()` 时会自动创建数据库文件和表结构。 diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..b09e75f --- /dev/null +++ b/go.mod @@ -0,0 +1,30 @@ +module sunhpc + +go 1.25.5 + +require ( + github.com/spf13/cobra v1.10.2 + github.com/spf13/viper v1.21.0 + gopkg.in/yaml.v3 v3.0.1 +) + +require ( + github.com/fatih/color v1.18.0 // indirect + github.com/fsnotify/fsnotify v1.9.0 // indirect + github.com/go-viper/mapstructure/v2 v2.4.0 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-sqlite3 v1.14.34 // indirect + github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/sagikazarmark/locafero v0.11.0 // indirect + github.com/sirupsen/logrus v1.9.4 // indirect + github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect + github.com/spf13/afero v1.15.0 // indirect + github.com/spf13/cast v1.10.0 // indirect + github.com/spf13/pflag v1.0.10 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + go.yaml.in/yaml/v3 v3.0.4 // indirect + golang.org/x/sys v0.29.0 // indirect + golang.org/x/text v0.28.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..17e108b --- /dev/null +++ b/go.sum @@ -0,0 +1,67 @@ +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= +github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= +github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk= +github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= +github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc= +github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik= +github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w= +github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g= +github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw= +github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U= +github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= +github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= +github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= +github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= +github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/auth/auth.go b/internal/auth/auth.go new file mode 100644 index 0000000..7aed4f9 --- /dev/null +++ b/internal/auth/auth.go @@ -0,0 +1,14 @@ +package auth + +import ( + "fmt" + "os" +) + +// RequireRoot 检查是否以 root 身份运行 +func RequireRoot() error { + if os.Geteuid() != 0 { + return fmt.Errorf("此操作需要 root 权限,请使用 sudo 或切换到 root 用户") + } + return nil +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..3d51259 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,69 @@ +package config + +import ( + "os" + "path/filepath" + + "gopkg.in/yaml.v3" +) + +const ( + BaseDir = "/etc/sunhpc" + LogDir = "/var/log/sunhpc" + TmplDir = BaseDir + "/tmpl.d" +) + +var ( + SunHPCFile = filepath.Join(BaseDir, "sunhpc.yaml") + NodesFile = filepath.Join(BaseDir, "nodes.yaml") + NetworkFile = filepath.Join(BaseDir, "network.yaml") + DisksFile = filepath.Join(BaseDir, "disks.yaml") + ServicesFile = filepath.Join(BaseDir, "services.yaml") + FirewallFile = filepath.Join(BaseDir, "iptables.yaml") +) + +// InitDirs 创建所有必需目录 +func InitDirs() error { + dirs := []string{ + BaseDir, + TmplDir, + LogDir, + } + for _, d := range dirs { + if err := os.MkdirAll(d, 0755); err != nil { + return err + } + } + return nil +} + +// CreateDefaultConfigs 生成默认 YAML 配置文件 +func CreateDefaultConfigs() error { + files := map[string]interface{}{ + SunHPCFile: DefaultSunHPC(), + NodesFile: DefaultNodes(), + NetworkFile: DefaultNetwork(), + DisksFile: DefaultDisks(), + ServicesFile: DefaultServices(), + FirewallFile: DefaultFirewall(), + } + + for path, data := range files { + if err := writeYAML(path, data); err != nil { + return err + } + } + return nil +} + +func writeYAML(path string, data interface{}) error { + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + + enc := yaml.NewEncoder(f) + defer enc.Close() + return enc.Encode(data) +} diff --git a/internal/config/defaults.go b/internal/config/defaults.go new file mode 100644 index 0000000..354987b --- /dev/null +++ b/internal/config/defaults.go @@ -0,0 +1,128 @@ +package config + +// SunHPC 主配置 +type SunHPCConfig struct { + Hostname string `yaml:"hostname"` + MOTD string `yaml:"motd"` + Sysctl map[string]string `yaml:"sysctl"` + SELinux string `yaml:"selinux"` // enforcing, permissive, disabled + SSH SSHConfig `yaml:"ssh"` +} + +type SSHConfig struct { + PermitRootLogin string `yaml:"permit_root_login"` + PasswordAuth string `yaml:"password_authentication"` +} + +func DefaultSunHPC() *SunHPCConfig { + return &SunHPCConfig{ + Hostname: "sunhpc-master", + MOTD: "Welcome to SunHPC Cluster\n", + Sysctl: map[string]string{ + "net.ipv4.ip_forward": "1", + "vm.swappiness": "10", + }, + SELinux: "enforcing", + SSH: SSHConfig{ + PermitRootLogin: "yes", + PasswordAuth: "yes", + }, + } +} + +// Nodes 节点配置 +type NodesConfig struct { + Nodes []Node `yaml:"nodes"` +} + +type Node struct { + Hostname string `yaml:"hostname"` + MAC string `yaml:"mac"` + IP string `yaml:"ip"` + Role string `yaml:"role"` // master, compute, login +} + +func DefaultNodes() *NodesConfig { + return &NodesConfig{ + Nodes: []Node{ + {Hostname: "master", MAC: "00:11:22:33:44:55", IP: "192.168.1.1", Role: "master"}, + }, + } +} + +// Network 网络配置 +type NetworkConfig struct { + Interface string `yaml:"interface"` + Subnet string `yaml:"subnet"` + Netmask string `yaml:"netmask"` + Gateway string `yaml:"gateway"` + DNSServers []string `yaml:"dns_servers"` +} + +func DefaultNetwork() *NetworkConfig { + return &NetworkConfig{ + Interface: "eth0", + Subnet: "192.168.1.0", + Netmask: "255.255.255.0", + Gateway: "192.168.1.1", + DNSServers: []string{"8.8.8.8", "114.114.114.114"}, + } +} + +// Disks 磁盘配置 +type DisksConfig struct { + Disks []Disk `yaml:"disks"` +} + +type Disk struct { + Device string `yaml:"device"` + Mount string `yaml:"mount"` + FSType string `yaml:"fstype"` + Options string `yaml:"options"` +} + +func DefaultDisks() *DisksConfig { + return &DisksConfig{ + Disks: []Disk{ + {Device: "/dev/sda1", Mount: "/", FSType: "ext4", Options: "defaults"}, + }, + } +} + +// Services 服务配置 +type ServicesConfig struct { + HTTPD Service `yaml:"httpd"` + TFTPD Service `yaml:"tftpd"` + DHCPD Service `yaml:"dhcpd"` +} + +type Service struct { + Enabled bool `yaml:"enabled"` + Config string `yaml:"config,omitempty"` +} + +func DefaultServices() *ServicesConfig { + return &ServicesConfig{ + HTTPD: Service{Enabled: true}, + TFTPD: Service{Enabled: true}, + DHCPD: Service{Enabled: true}, + } +} + +// Firewall 防火墙配置 +type FirewallConfig struct { + DefaultPolicy string `yaml:"default_policy"` + Rules []string `yaml:"rules"` +} + +func DefaultFirewall() *FirewallConfig { + return &FirewallConfig{ + DefaultPolicy: "DROP", + Rules: []string{ + "-A INPUT -m state --state ESTABLISHED,RELATED -j ACCEPT", + "-A INPUT -p icmp -j ACCEPT", + "-A INPUT -i lo -j ACCEPT", + "-A INPUT -p tcp --dport 22 -j ACCEPT", + }, + } +} diff --git a/internal/config/loaders.go b/internal/config/loaders.go new file mode 100644 index 0000000..2e14686 --- /dev/null +++ b/internal/config/loaders.go @@ -0,0 +1,43 @@ +package config + +import ( + "os" + + "gopkg.in/yaml.v3" +) + +func LoadSunHPC() (*SunHPCConfig, error) { + return loadYAML[SunHPCConfig](SunHPCFile) +} + +func LoadNodes() (*NodesConfig, error) { + return loadYAML[NodesConfig](NodesFile) +} + +func LoadNetwork() (*NetworkConfig, error) { + return loadYAML[NetworkConfig](NetworkFile) +} + +func LoadDisks() (*DisksConfig, error) { + return loadYAML[DisksConfig](DisksFile) +} + +func LoadServices() (*ServicesConfig, error) { + return loadYAML[ServicesConfig](ServicesFile) +} + +func LoadFirewall() (*FirewallConfig, error) { + return loadYAML[FirewallConfig](FirewallFile) +} + +func loadYAML[T any](path string) (*T, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + var cfg T + if err := yaml.Unmarshal(data, &cfg); err != nil { + return nil, err + } + return &cfg, nil +} diff --git a/internal/db/db.go b/internal/db/db.go new file mode 100644 index 0000000..d8e774a --- /dev/null +++ b/internal/db/db.go @@ -0,0 +1,794 @@ +package db + +import ( + "context" + "database/sql" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "sunhpc/internal/log" + + _ "github.com/mattn/go-sqlite3" +) + +// 全局单例 +var ( + globalDB *DB + once sync.Once +) + +// DB 核心数据库类 - 对应Rocks的Database类 +type DB struct { + // 连接参数 + dbUser string + dbPasswd string + dbHost string + dbName string + dbPath string + dbSocket string + verbose bool + forceInit bool + + // 连接对象 + engine *sql.DB // 连接池 + conn *sql.Conn // 当前连接 + results *sql.Rows // 当前结果集 + + // 线程本地存储模拟 + sessions sync.Map + + mu sync.RWMutex +} + +// NewDB 创建新实例 +func NewDB() *DB { + return &DB{ + dbUser: "", + dbPasswd: "", + dbHost: "localhost", + dbName: "sunhpc", + dbPath: "/var/lib/sunhpc", + dbSocket: "/var/lib/sunhpc/mysql/mysql.sock", + verbose: false, + } +} + +// ==================== 连接参数设置/获取 ==================== + +func (db *DB) SetDBPasswd(passwd string) { + db.mu.Lock() + defer db.mu.Unlock() + db.dbPasswd = passwd +} + +func (db *DB) GetDBPasswd() string { + db.mu.RLock() + if db.dbPasswd != "" { + db.mu.RUnlock() + return db.dbPasswd + } + db.mu.RUnlock() + + db.mu.Lock() + defer db.mu.Unlock() + + // 从配置文件读取密码 + username := db.GetDBUsername() + var filename string + switch username { + case "root": + filename = "/root/.sunhpc.my.cnf" + default: + filename = fmt.Sprintf("/home/%s/.sunhpc.my.cnf", username) + } + + data, err := ioutil.ReadFile(filename) + if err != nil { + return "" + } + + lines := strings.Split(string(data), "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + parts := strings.Split(line, "=") + if len(parts) == 2 && strings.TrimSpace(parts[0]) == "password" { + db.dbPasswd = strings.TrimSpace(parts[1]) + break + } + } + return db.dbPasswd +} + +func (db *DB) SetDBUsername(name string) { + db.mu.Lock() + defer db.mu.Unlock() + db.dbUser = name +} + +func (db *DB) GetDBUsername() string { + db.mu.RLock() + if db.dbUser != "" { + db.mu.RUnlock() + return db.dbUser + } + db.mu.RUnlock() + + db.mu.Lock() + defer db.mu.Unlock() + db.dbUser = os.Getenv("USER") + return db.dbUser +} + +func (db *DB) SetDBHostname(host string) { + db.mu.Lock() + defer db.mu.Unlock() + db.dbHost = host +} + +func (db *DB) GetDBHostname() string { + db.mu.RLock() + defer db.mu.RUnlock() + return db.dbHost +} + +func (db *DB) SetDBName(name string) { + db.mu.Lock() + defer db.mu.Unlock() + db.dbName = name +} + +func (db *DB) GetDBName() string { + db.mu.RLock() + defer db.mu.RUnlock() + return db.dbName +} + +func (db *DB) SetDBPath(path string) { + db.mu.Lock() + defer db.mu.Unlock() + db.dbPath = path +} + +func (db *DB) GetDBPath() string { + db.mu.RLock() + defer db.mu.RUnlock() + return db.dbPath +} + +func (db *DB) SetVerbose(verbose bool) { + db.mu.Lock() + defer db.mu.Unlock() + db.verbose = verbose +} + +func (db *DB) SetForceInit(force bool) { + db.mu.Lock() + defer db.mu.Unlock() + db.forceInit = force +} + +// ==================== 连接管理 ==================== + +// Connect 连接数据库 +func (db *DB) Connect() error { + log.Debug("连接数据库...") + db.mu.Lock() + defer db.mu.Unlock() + + log.Debug("检查 SUNHPCDEBUG 环境变量...") + if os.Getenv("SUNHPCDEBUG") != "" { + db.verbose = true + } + + // 使用SQLite + dbFullPath := filepath.Join(db.dbPath, db.dbName+".db") + log.Debugf("数据库路径: %s", dbFullPath) + + // 确保目录存在 + log.Debug("确保数据库目录存在...") + os.MkdirAll(db.dbPath, 0755) + + engine, err := sql.Open("sqlite3", dbFullPath+"?_foreign_keys=on&_journal_mode=WAL") + log.Debugf("打开数据库连接...") + if err != nil { + return fmt.Errorf("打开数据库失败: %v", err) + } + + engine.SetMaxOpenConns(10) + engine.SetMaxIdleConns(5) + engine.SetConnMaxLifetime(time.Hour) + + db.engine = engine + + conn, err := engine.Conn(context.Background()) + log.Debugf("获取数据库连接...") + if err != nil { + return fmt.Errorf("获取连接失败: %v", err) + } + db.conn = conn + + // 初始化数据库表 + if err := db.initSchema(); err != nil { + return fmt.Errorf("初始化数据库表失败: %v", err) + } + + if db.verbose { + log.Infof("数据库连接成功: %s", dbFullPath) + } + + return nil +} + +// initSchema 初始化数据库表结构 - 所有表定义在这里 +func (db *DB) initSchema() error { + log.Debug("初始化数据库表结构...") + + // 检查 nodes 表是否已存在 + var tableName string + err := db.engine.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name='nodes'").Scan(&tableName) + + if err == nil && !db.forceInit { + log.Debug("数据库表已存在,跳过初始化") + return nil + } + + if db.forceInit { + log.Warn("强制重新初始化数据库表结构...") + } else { + log.Info("首次初始化数据库表结构...") + } + + // 如果强制初始化,先删除所有表 + if db.forceInit { + log.Info("删除现有表...") + dropSQLs := []string{ + `DROP TABLE IF EXISTS resolvechain;`, + `DROP TABLE IF EXISTS hostselections;`, + `DROP TABLE IF EXISTS attributes;`, + `DROP TABLE IF EXISTS catindexes;`, + `DROP TABLE IF EXISTS categories;`, + `DROP TABLE IF EXISTS node_attrs;`, + `DROP TABLE IF EXISTS aliases;`, + `DROP TABLE IF EXISTS networks;`, + `DROP TABLE IF EXISTS subnets;`, + `DROP TABLE IF EXISTS software_installs;`, + `DROP TABLE IF EXISTS memberships;`, + `DROP TABLE IF EXISTS appliances;`, + `DROP TABLE IF EXISTS nodes;`, + } + + for _, sql := range dropSQLs { + if _, err := db.engine.Exec(sql); err != nil { + log.Warnf("删除表失败: %v", err) + } + } + log.Info("现有表已删除") + } + + // 开启事务 + tx, err := db.engine.Begin() + if err != nil { + return fmt.Errorf("开启事务失败: %v", err) + } + + // 使用exec执行,每条SQL单独执行 + sqls := []string{ + // 创建表 - 注意创建顺序(先创建主表,再创建有外键的表) + `CREATE TABLE IF NOT EXISTS nodes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + rack INTEGER DEFAULT 0, + rank INTEGER DEFAULT 0, + membership_id INTEGER, + cpus INTEGER DEFAULT 0, + memory INTEGER DEFAULT 0, + disk INTEGER DEFAULT 0, + os TEXT, + kernel TEXT, + last_state_change DATETIME DEFAULT CURRENT_TIMESTAMP, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP + );`, + + `CREATE TABLE IF NOT EXISTS appliances ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + description TEXT, + node_type TEXT DEFAULT 'compute' + );`, + + `CREATE TABLE IF NOT EXISTS memberships ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + appliance_id INTEGER, + FOREIGN KEY (appliance_id) REFERENCES appliances(id) ON DELETE SET NULL + );`, + + `CREATE TABLE IF NOT EXISTS subnets ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT UNIQUE, + network TEXT, + netmask TEXT, + gateway TEXT, + dns_zone TEXT, + is_private INTEGER DEFAULT 1 + );`, + + `CREATE TABLE IF NOT EXISTS networks ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + node_id INTEGER NOT NULL, + name TEXT, + ip TEXT UNIQUE, + mac TEXT UNIQUE, + subnet_id INTEGER, + interface TEXT DEFAULT 'eth0', + FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE CASCADE, + FOREIGN KEY (subnet_id) REFERENCES subnets(id) ON DELETE SET NULL + );`, + + `CREATE TABLE IF NOT EXISTS aliases ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + node_id INTEGER NOT NULL, + name TEXT NOT NULL, + FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE CASCADE, + UNIQUE(node_id, name) + );`, + + `CREATE TABLE IF NOT EXISTS categories ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE + );`, + + `CREATE TABLE IF NOT EXISTS catindexes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + category_id INTEGER NOT NULL, + FOREIGN KEY (category_id) REFERENCES categories(id) ON DELETE CASCADE, + UNIQUE(name, category_id) + );`, + + `CREATE TABLE IF NOT EXISTS attributes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + attr TEXT NOT NULL, + value TEXT, + category_id INTEGER NOT NULL, + catindex_id INTEGER NOT NULL, + FOREIGN KEY (category_id) REFERENCES categories(id) ON DELETE CASCADE, + FOREIGN KEY (catindex_id) REFERENCES catindexes(id) ON DELETE CASCADE, + UNIQUE(attr, category_id, catindex_id) + );`, + + `CREATE TABLE IF NOT EXISTS node_attrs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + node_id INTEGER NOT NULL, + attr TEXT NOT NULL, + value TEXT, + FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE CASCADE, + UNIQUE(node_id, attr) + );`, + + `CREATE TABLE IF NOT EXISTS hostselections ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + host_id INTEGER NOT NULL, + category_id INTEGER NOT NULL, + selection TEXT NOT NULL, + FOREIGN KEY (host_id) REFERENCES nodes(id) ON DELETE CASCADE, + FOREIGN KEY (category_id) REFERENCES categories(id) ON DELETE CASCADE, + UNIQUE(host_id, category_id, selection) + );`, + + `CREATE TABLE IF NOT EXISTS resolvechain ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + category_id INTEGER NOT NULL, + precedence INTEGER NOT NULL, + FOREIGN KEY (category_id) REFERENCES categories(id) ON DELETE CASCADE, + UNIQUE(category_id, precedence) + );`, + + `CREATE TABLE IF NOT EXISTS software_installs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + version TEXT, + install_type TEXT, + node_id INTEGER, + status TEXT, + installed_at DATETIME DEFAULT CURRENT_TIMESTAMP, + installed_by TEXT, + FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE SET NULL + );`, + + // 创建索引 + `CREATE INDEX IF NOT EXISTS idx_nodes_name ON nodes(name);`, + `CREATE INDEX IF NOT EXISTS idx_networks_ip ON networks(ip);`, + `CREATE INDEX IF NOT EXISTS idx_networks_mac ON networks(mac);`, + `CREATE INDEX IF NOT EXISTS idx_attributes_lookup ON attributes(attr, category_id, catindex_id);`, + `CREATE INDEX IF NOT EXISTS idx_node_attrs_lookup ON node_attrs(node_id, attr);`, + `CREATE INDEX IF NOT EXISTS idx_hostselections_host ON hostselections(host_id);`, + `CREATE INDEX IF NOT EXISTS idx_resolvechain_precedence ON resolvechain(precedence);`, + } + + // 逐条执行SQL + for i, sql := range sqls { + if strings.TrimSpace(sql) == "" { + continue + } + + log.Debugf("执行SQL[%d]: %s", i, strings.TrimSpace(strings.Split(sql, "\n")[0])) + + _, err := tx.Exec(sql) + if err != nil { + tx.Rollback() + return fmt.Errorf("执行SQL[%d]失败: %v\nSQL: %s", i, err, sql) + } + } + + // 提交事务 + if err := tx.Commit(); err != nil { + return fmt.Errorf("提交事务失败: %v", err) + } + + log.Info("数据库表结构创建成功") + + // 插入默认数据 + return db.insertDefaultData() +} + +// insertDefaultData 插入默认数据 +func (db *DB) insertDefaultData() error { + log.Debug("插入默认数据...") + // 默认类别 + categories := []string{"global", "host", "os", "appliance", "network"} + for _, cat := range categories { + _, err := db.engine.Exec( + "INSERT OR IGNORE INTO categories (name) VALUES (?)", + cat, + ) + if err != nil { + return err + } + } + + log.Debug("插入默认类别索引...") + // 默认类别索引 + catIndexes := []struct { + catName string + idxName string + }{ + {"global", "global"}, + {"os", "linux"}, + {"network", "private"}, + } + + for _, ci := range catIndexes { + _, err := db.engine.Exec(` + INSERT OR IGNORE INTO catindexes (name, category_id) + SELECT ?, id FROM categories WHERE name = ? + `, ci.idxName, ci.catName) + if err != nil { + return err + } + } + + log.Debug("插入默认解析链优先级...") + // 默认解析链优先级 + precedence := []struct { + catName string + level int + }{ + {"global", 1}, + {"os", 2}, + {"appliance", 3}, + {"host", 4}, + {"network", 5}, + } + + for _, p := range precedence { + _, err := db.engine.Exec(` + INSERT OR IGNORE INTO resolvechain (category_id, precedence) + SELECT id, ? FROM categories WHERE name = ? + `, p.level, p.catName) + if err != nil { + return err + } + } + + log.Debug("插入默认设备类型...") + // 默认设备类型 + appliances := []struct { + name string + desc string + typ string + }{ + {"frontend", "管理节点", "master"}, + {"compute", "计算节点", "compute"}, + {"login", "登录节点", "login"}, + {"storage", "存储节点", "storage"}, + } + + for _, a := range appliances { + _, err := db.engine.Exec( + "INSERT OR IGNORE INTO appliances (name, description, node_type) VALUES (?, ?, ?)", + a.name, a.desc, a.typ, + ) + if err != nil { + return err + } + } + + log.Debug("插入默认数据完成...") + return nil +} + +// ==================== 核心查询方法 ==================== + +// Execute 执行SQL语句 - 对应Rocks的execute() +func (db *DB) Execute(query string, args ...interface{}) (int64, error) { + db.mu.RLock() + conn := db.conn + verbose := db.verbose + db.mu.RUnlock() + + if conn == nil { + return 0, fmt.Errorf("没有活动数据库连接") + } + + if verbose { + log.Debugf("执行SQL: %s %v", query, args) + } + + // 判断SQL类型 + upperQuery := strings.ToUpper(strings.TrimSpace(query)) + isSelect := strings.HasPrefix(upperQuery, "SELECT") + + if isSelect { + // SELECT 查询使用 QueryContext + rows, err := conn.QueryContext(context.Background(), query, args...) + if err != nil { + // 尝试重连一次 + db.RenewConnection() + db.mu.RLock() + conn = db.conn + db.mu.RUnlock() + rows, err = conn.QueryContext(context.Background(), query, args...) + } + + if err != nil { + return 0, err + } + + // 关闭旧结果 + db.mu.Lock() + if db.results != nil { + db.results.Close() + } + db.results = rows + db.mu.Unlock() + + return 0, nil + } else { + // INSERT/UPDATE/DELETE 使用 Exec(自动提交) + result, err := conn.ExecContext(context.Background(), query, args...) + if err != nil { + // 尝试重连一次 + db.RenewConnection() + db.mu.RLock() + conn = db.conn + db.mu.RUnlock() + result, err = conn.ExecContext(context.Background(), query, args...) + } + + if err != nil { + return 0, err + } + + // 获取影响行数 + rowsAffected, err := result.RowsAffected() + if err != nil { + return 0, err + } + + if verbose { + log.Debugf("影响行数: %d", rowsAffected) + } + + return rowsAffected, nil + } +} + +// FetchOne 获取一行 - 对应Rocks的fetchone() +// 返回map[string]interface{}格式,key为列名 +func (db *DB) FetchOne() (map[string]interface{}, error) { + db.mu.RLock() + results := db.results + db.mu.RUnlock() + + if results == nil { + return nil, nil + } + + if !results.Next() { + return nil, nil + } + + columns, err := results.Columns() + if err != nil { + return nil, err + } + + values := make([]interface{}, len(columns)) + scanArgs := make([]interface{}, len(columns)) + for i := range values { + scanArgs[i] = &values[i] + } + + err = results.Scan(scanArgs...) + if err != nil { + return nil, err + } + + row := make(map[string]interface{}) + for i, col := range columns { + val := values[i] + if b, ok := val.([]byte); ok { + row[col] = string(b) + } else { + row[col] = val + } + } + + return row, nil +} + +// FetchAll 获取所有行 - 对应Rocks的fetchall() +// 返回[]map[string]interface{}格式 +func (db *DB) FetchAll() ([]map[string]interface{}, error) { + db.mu.RLock() + results := db.results + db.mu.RUnlock() + + if results == nil { + return nil, nil + } + + columns, err := results.Columns() + if err != nil { + return nil, err + } + + var rows []map[string]interface{} + + for results.Next() { + values := make([]interface{}, len(columns)) + scanArgs := make([]interface{}, len(columns)) + for i := range values { + scanArgs[i] = &values[i] + } + + err = results.Scan(scanArgs...) + if err != nil { + return nil, err + } + + row := make(map[string]interface{}) + for i, col := range columns { + val := values[i] + if b, ok := val.([]byte); ok { + row[col] = string(b) + } else { + row[col] = val + } + } + + rows = append(rows, row) + } + + return rows, nil +} + +// ==================== 连接维护 ==================== + +// RenewConnection 续期连接 +func (db *DB) RenewConnection() error { + db.mu.Lock() + defer db.mu.Unlock() + + if db.conn != nil { + db.conn.Close() + } + + conn, err := db.engine.Conn(context.Background()) + if err != nil { + return err + } + db.conn = conn + return nil +} + +// Close 关闭连接 +func (db *DB) Close() error { + db.mu.Lock() + defer db.mu.Unlock() + + if db.results != nil { + db.results.Close() + db.results = nil + } + if db.conn != nil { + db.conn.Close() + db.conn = nil + } + if db.engine != nil { + return db.engine.Close() + } + return nil +} + +// CloseConnection 只关闭当前连接,不关闭连接池 +func (db *DB) CloseConnection() error { + db.mu.Lock() + defer db.mu.Unlock() + + if db.results != nil { + db.results.Close() + db.results = nil + } + if db.conn != nil { + db.conn.Close() + db.conn = nil + } + return nil +} + +// ==================== 单例模式 ==================== + +var ( + instanceConfigured bool + instanceDBPath string + instanceDBName string +) + +func GetInstance() (*DB, error) { + return GetInstanceWithConfig("", "") +} + +func GetInstanceWithConfig(dbPath, dbName string) (*DB, error) { + var err error + once.Do(func() { + globalDB = NewDB() + log.Debug("创建数据库实例...") + globalDB.SetDBUsername(globalDB.GetDBUsername()) + + if dbPath != "" { + globalDB.SetDBPath(dbPath) + log.Debugf("设置数据库路径: %s", dbPath) + } + if dbName != "" { + globalDB.SetDBName(dbName) + log.Debugf("设置数据库名称: %s", dbName) + } + + instanceConfigured = (dbPath != "" || dbName != "") + if dbPath != "" { + instanceDBPath = dbPath + } + if dbName != "" { + instanceDBName = dbName + } + + err = globalDB.Connect() + }) + return globalDB, err +} + +func IsInstanceConfigured() bool { + return instanceConfigured +} + +func GetInstanceConfig() (dbPath, dbName string) { + return instanceDBPath, instanceDBName +} diff --git a/internal/db/helper.go b/internal/db/helper.go new file mode 100644 index 0000000..b6144e8 --- /dev/null +++ b/internal/db/helper.go @@ -0,0 +1,624 @@ +package db + +import ( + "fmt" + "net" + "os" + "strings" + "sync" +) + +/* + // 获取数据库实例 + database, err := db.GetInstance() + if err != nil { + log.Fatal(err) + } + defer database.Close() + + // 创建Helper + helper, _ := db.NewDBHelper() + + // 执行查询 + helper.Execute("SELECT * FROM nodes WHERE rack = ?", 1) + + // 获取一行 + row, _ := helper.FetchOne() + if row != nil { + log.Infof("节点: %v", row["name"]) + } + + // 获取所有行 + rows, _ := helper.FetchAll() + log.Infof("共 %d 个节点", len(rows)) + + // 使用Helper方法 + hostname, _ := helper.GetHostname("192.168.1.1") + log.Infof("解析主机名: %s", hostname) + + // 设置属性 + helper.SetCategoryAttr("global", "global", "Kickstart_PrivateHostname", "sunhpc-master") + + // 获取属性 + val := helper.GetCategoryAttr("global", "global", "Kickstart_PrivateHostname") + log.Infof("前端主机名: %s", val) +*/ + +const ( + attrPostfix = "_old" +) + +// DBHelper DatabaseHelper类 - 继承DB,扩展业务方法 +type DBHelper struct { + *DB + appliancesList []string + frontendName string + cacheAttrs sync.Map +} + +func NewDBHelper() (*DBHelper, error) { + db, err := GetInstance() + if err != nil { + return nil, err + } + return &DBHelper{ + DB: db, + appliancesList: nil, + frontendName: "", + }, nil +} + +// ==================== 节点查询 ==================== + +// GetListHostnames 获取所有主机名列表 +func (h *DBHelper) GetListHostnames() ([]string, error) { + _, err := h.Execute("SELECT name FROM nodes ORDER BY name") + if err != nil { + return nil, err + } + + rows, err := h.FetchAll() + if err != nil { + return nil, err + } + + var names []string + for _, row := range rows { + if name, ok := row["name"]; ok { + names = append(names, name.(string)) + } + } + return names, nil +} + +// GetNodesFromNames 从名称列表获取节点 +func (h *DBHelper) GetNodesFromNames(names []string, managedOnly bool) ([]map[string]interface{}, error) { + // 如果没有提供名称,返回所有节点 + if len(names) == 0 { + query := "SELECT * FROM nodes" + if managedOnly { + query = ` + SELECT n.* FROM nodes n + JOIN node_attrs a ON n.id = a.node_id + WHERE a.attr = 'managed' AND a.value = 'true' + ` + } + + _, err := h.Execute(query) + if err != nil { + return nil, err + } + return h.FetchAll() + } + + // 构建查询条件 + conditions := []string{} + args := []interface{}{} + + for _, name := range names { + if strings.HasPrefix(name, "select ") { + conditions = append(conditions, fmt.Sprintf("name IN (%s)", name[7:])) + + } else if strings.Contains(name, "%") { + conditions = append(conditions, "name LIKE ?") + args = append(args, name) + + } else if strings.HasPrefix(name, "rack") { + rackNum := strings.TrimPrefix(name, "rack") + conditions = append(conditions, "rack = ?") + args = append(args, rackNum) + + } else if h.IsApplianceName(name) { + conditions = append(conditions, `id IN ( + SELECT node_id FROM node_attrs + WHERE attr = 'appliance' AND value = ? + )`) + args = append(args, name) + + } else { + hostname, err := h.GetHostname(name) + if err == nil { + conditions = append(conditions, "name = ?") + args = append(args, hostname) + } + } + } + + if len(conditions) == 0 { + return []map[string]interface{}{}, nil + } + + query := "SELECT * FROM nodes WHERE " + strings.Join(conditions, " OR ") + _, err := h.Execute(query, args...) + if err != nil { + return nil, err + } + + nodes, err := h.FetchAll() + if err != nil { + return nil, err + } + + // 过滤受管节点 + if managedOnly { + var managed []map[string]interface{} + for _, node := range nodes { + val := h.GetHostAttr(node["name"].(string), "managed") + if val == "true" { + managed = append(managed, node) + } + } + return managed, nil + } + + return nodes, nil +} + +// ==================== 设备类型 ==================== + +// GetAppliancesListText 获取所有设备类型名称 +func (h *DBHelper) GetAppliancesListText() []string { + if h.appliancesList != nil { + return h.appliancesList + } + + _, err := h.Execute("SELECT DISTINCT value FROM node_attrs WHERE attr = 'appliance'") + if err != nil { + return []string{} + } + + rows, _ := h.FetchAll() + var apps []string + for _, row := range rows { + if val, ok := row["value"]; ok { + apps = append(apps, val.(string)) + } + } + + h.appliancesList = apps + return apps +} + +// IsApplianceName 检查是否为设备类型名称 +func (h *DBHelper) IsApplianceName(name string) bool { + for _, app := range h.GetAppliancesListText() { + if app == name { + return true + } + } + return false +} + +// ==================== 主机名解析 ==================== + +// GetHostname 规范化主机名 - 完全参考Rocks实现 +func (h *DBHelper) GetHostname(hostname string) (string, error) { + // 如果hostname为空,使用系统主机名 + if hostname == "" { + hostname, _ = os.Hostname() + hostname = strings.Split(hostname, ".")[0] + return h.GetHostname(hostname) + } + + // 1. 直接在nodes表中查找 + _, err := h.Execute("SELECT * FROM nodes WHERE name = ?", hostname) + if err == nil { + row, _ := h.FetchOne() + if row != nil { + return hostname, nil + } + } + + // 2. 尝试IP地址反向解析 + addr := net.ParseIP(hostname) + if addr != nil { + names, err := net.LookupAddr(hostname) + if err == nil && len(names) > 0 { + return h.GetHostname(strings.Split(names[0], ".")[0]) + } + } + + // 3. 在networks表中查找IP + if addr != nil { + _, err := h.Execute(` + SELECT n.name FROM nodes n + JOIN networks net ON n.id = net.node_id + WHERE net.ip = ? + `, addr.String()) + if err == nil { + row, _ := h.FetchOne() + if row != nil { + return row["name"].(string), nil + } + } + } + + // 4. 尝试MAC地址 + mac := strings.ReplaceAll(hostname, "-", ":") + _, err = h.Execute(` + SELECT n.name FROM nodes n + JOIN networks net ON n.id = net.node_id + WHERE net.mac = ? + `, mac) + if err == nil { + row, _ := h.FetchOne() + if row != nil { + return row["name"].(string), nil + } + } + + // 5. 检查别名 + _, err = h.Execute(` + SELECT n.name FROM nodes n + JOIN aliases a ON n.id = a.node_id + WHERE a.name = ? + `, hostname) + if err == nil { + row, _ := h.FetchOne() + if row != nil { + return row["name"].(string), nil + } + } + + // 6. 尝试FQDN + if strings.Contains(hostname, ".") { + parts := strings.Split(hostname, ".") + name := parts[0] + domain := strings.Join(parts[1:], ".") + + _, err := h.Execute(` + SELECT n.name FROM nodes n + JOIN networks net ON n.id = net.node_id + JOIN subnets s ON net.subnet_id = s.id + WHERE s.dns_zone = ? AND (net.name = ? OR n.name = ?) + `, domain, name, name) + if err == nil { + row, _ := h.FetchOne() + if row != nil { + return row["name"].(string), nil + } + } + } + + // 7. 如果以上都失败,抛出异常 + return "", fmt.Errorf("无法解析主机名: %s", hostname) +} + +// CheckHostnameValidity 检查主机名有效性 +func (h *DBHelper) CheckHostnameValidity(hostname string) error { + // 不能包含点 + if strings.Contains(hostname, ".") { + return fmt.Errorf("主机名 %s 不能包含点号", hostname) + } + + // 不能是rack<数字>格式 + if strings.HasPrefix(hostname, "rack") { + num := strings.TrimPrefix(hostname, "rack") + if _, err := fmt.Sscanf(num, "%d", new(int)); err == nil { + return fmt.Errorf("主机名 %s 不能是rack<数字>格式", hostname) + } + } + + // 不能是设备类型名称 + if h.IsApplianceName(hostname) { + return fmt.Errorf("主机名 %s 不能与设备类型名称相同", hostname) + } + + // 检查是否已存在 + _, err := h.GetHostname(hostname) + if err == nil { + return fmt.Errorf("节点 %s 已存在", hostname) + } + + return nil +} + +// ==================== 前端节点 ==================== + +// GetFrontendName 获取前端节点名称 +func (h *DBHelper) GetFrontendName() string { + if h.frontendName != "" { + return h.frontendName + } + + name := h.GetCategoryAttr("global", "global", "Kickstart_PrivateHostname") + if name != "" { + h.frontendName = name + } + return h.frontendName +} + +// ==================== 属性管理 ==================== + +// GetCategoryIndex 获取类别索引 +func (h *DBHelper) GetCategoryIndex(categoryName, categoryIndex string) (map[string]interface{}, map[string]interface{}, error) { + // 查询类别和索引 + _, err := h.Execute(` + SELECT c.id as cid, c.name as cname, i.id as iid, i.name as iname + FROM categories c + JOIN catindexes i ON c.id = i.category_id + WHERE c.name = ? AND i.name = ? + `, categoryName, categoryIndex) + + if err == nil { + row, _ := h.FetchOne() + if row != nil { + category := map[string]interface{}{ + "id": row["cid"], + "name": row["cname"], + } + catindex := map[string]interface{}{ + "id": row["iid"], + "name": row["iname"], + "category_id": row["cid"], + } + return category, catindex, nil + } + } + + // 不存在则创建 + // 创建类别 + _, err = h.Execute("INSERT INTO categories (name) VALUES (?)", categoryName) + if err != nil { + return nil, nil, err + } + + var catID int64 + h.Execute("SELECT last_insert_rowid()") + row, _ := h.FetchOne() + if row != nil { + catID = row["last_insert_rowid()"].(int64) + } + + // 创建索引 + _, err = h.Execute( + "INSERT INTO catindexes (name, category_id) VALUES (?, ?)", + categoryIndex, catID, + ) + if err != nil { + return nil, nil, err + } + + h.Execute("SELECT last_insert_rowid()") + row, _ = h.FetchOne() + var idxID int64 + if row != nil { + idxID = row["last_insert_rowid()"].(int64) + } + + category := map[string]interface{}{ + "id": catID, + "name": categoryName, + } + catindex := map[string]interface{}{ + "id": idxID, + "name": categoryIndex, + "category_id": catID, + } + + return category, catindex, nil +} + +// SetCategoryAttr 设置类别属性 +func (h *DBHelper) SetCategoryAttr(categoryName, catindexName, attr, value string) error { + cat, catindex, err := h.GetCategoryIndex(categoryName, catindexName) + if err != nil { + return err + } + + // 查询现有属性 + _, err = h.Execute(` + SELECT id, value FROM attributes + WHERE attr = ? AND category_id = ? AND catindex_id = ? + `, attr, cat["id"], catindex["id"]) + + if err == nil { + row, _ := h.FetchOne() + if row != nil { + // 更新现有属性 + oldValue := row["value"] + attrID := row["id"] + + _, err = h.Execute( + "UPDATE attributes SET value = ? WHERE id = ?", + value, attrID, + ) + if err != nil { + return err + } + + // 保存旧值 + if !strings.HasSuffix(attr, attrPostfix) { + h.SetCategoryAttr(categoryName, catindexName, attr+attrPostfix, oldValue.(string)) + } + return nil + } + } + + // 创建新属性 + _, err = h.Execute(` + INSERT INTO attributes (attr, value, category_id, catindex_id) + VALUES (?, ?, ?, ?) + `, attr, value, cat["id"], catindex["id"]) + + return err +} + +// GetCategoryAttr 获取类别属性 +func (h *DBHelper) GetCategoryAttr(categoryName, catindexName, attrName string) string { + cat, catindex, err := h.GetCategoryIndex(categoryName, catindexName) + if err != nil { + return "" + } + + _, err = h.Execute(` + SELECT value FROM attributes + WHERE attr = ? AND category_id = ? AND catindex_id = ? + `, attrName, cat["id"], catindex["id"]) + + if err != nil { + return "" + } + + row, _ := h.FetchOne() + if row == nil { + return "" + } + + return row["value"].(string) +} + +// RemoveCategoryAttr 移除类别属性 +func (h *DBHelper) RemoveCategoryAttr(categoryName, catindexName, attrName string) error { + cat, catindex, err := h.GetCategoryIndex(categoryName, catindexName) + if err != nil { + return err + } + + _, err = h.Execute(` + DELETE FROM attributes + WHERE attr = ? AND category_id = ? AND catindex_id = ? + `, attrName, cat["id"], catindex["id"]) + + if err != nil { + return err + } + + // 同时删除对应的_old属性 + _, _ = h.Execute(` + DELETE FROM attributes + WHERE attr = ? AND category_id = ? AND catindex_id = ? + `, attrName+attrPostfix, cat["id"], catindex["id"]) + + return nil +} + +// ==================== 主机属性 ==================== + +// GetHostAttr 获取主机属性 +func (h *DBHelper) GetHostAttr(hostname, attr string) string { + // 先从节点直接属性查询 + _, err := h.Execute(` + SELECT value FROM node_attrs + WHERE node_id = (SELECT id FROM nodes WHERE name = ?) + AND attr = ? + `, hostname, attr) + + if err == nil { + row, _ := h.FetchOne() + if row != nil { + return row["value"].(string) + } + } + + // 使用Rocks的属性解析链查询 + query := ` + SELECT a.value FROM attributes a + JOIN resolvechain r ON a.category_id = r.category_id + JOIN hostselections h ON a.category_id = h.category_id + AND a.catindex_id = h.selection + WHERE h.host_id = (SELECT id FROM nodes WHERE name = ?) + AND a.attr = ? + ORDER BY r.precedence DESC + LIMIT 1 + ` + + _, err = h.Execute(query, hostname, attr) + if err != nil { + return "" + } + + row, _ := h.FetchOne() + if row == nil { + return "" + } + + return row["value"].(string) +} + +// GetHostAttrs 获取主机所有属性 +func (h *DBHelper) GetHostAttrs(hostname string, showSource bool) map[string]interface{} { + attrs := make(map[string]interface{}) + + // 获取节点基本信息 + _, err := h.Execute(` + SELECT n.id, n.name, n.rack, n.rank, m.name as membership, a.name as appliance + FROM nodes n + LEFT JOIN memberships m ON n.membership_id = m.id + LEFT JOIN appliances a ON m.appliance_id = a.id + WHERE n.name = ? + `, hostname) + + if err == nil { + row, _ := h.FetchOne() + if row != nil { + if showSource { + attrs["hostname"] = []interface{}{row["name"], "I"} + attrs["rack"] = []interface{}{row["rack"], "I"} + attrs["rank"] = []interface{}{row["rank"], "I"} + attrs["appliance"] = []interface{}{row["appliance"], "I"} + attrs["membership"] = []interface{}{row["membership"], "I"} + } else { + attrs["hostname"] = row["name"] + attrs["rack"] = row["rack"] + attrs["rank"] = row["rank"] + attrs["appliance"] = row["appliance"] + attrs["membership"] = row["membership"] + } + } + } + + // 获取所有属性 + query := ` + SELECT a.attr, a.value, + CASE + WHEN h.host_id IS NOT NULL THEN 'H' + ELSE UPPER(SUBSTR(c.name, 1, 1)) + END as source + FROM attributes a + JOIN categories c ON a.category_id = c.id + LEFT JOIN hostselections h ON a.category_id = h.category_id + AND a.catindex_id = h.selection + AND h.host_id = (SELECT id FROM nodes WHERE name = ?) + UNION + SELECT attr, value, 'N' as source + FROM node_attrs + WHERE node_id = (SELECT id FROM nodes WHERE name = ?) + ` + + _, err = h.Execute(query, hostname, hostname) + if err == nil { + rows, _ := h.FetchAll() + for _, row := range rows { + attr := row["attr"].(string) + value := row["value"] + if showSource { + attrs[attr] = []interface{}{value, row["source"]} + } else { + attrs[attr] = value + } + } + } + + return attrs +} diff --git a/internal/log/logger.go b/internal/log/logger.go new file mode 100644 index 0000000..51391bf --- /dev/null +++ b/internal/log/logger.go @@ -0,0 +1,344 @@ +package log + +import ( + "fmt" + "io" + "os" + "path/filepath" + "runtime" + "sync" + "time" + + "github.com/fatih/color" +) + +// 日志级别 +type Level int + +const ( + DebugLevel Level = iota + InfoLevel + WarnLevel + ErrorLevel + FatalLevel +) + +// 级别名称 +var levelNames = map[Level]string{ + DebugLevel: "DEBUG", + InfoLevel: "INFO", + WarnLevel: "WARN", + ErrorLevel: "ERROR", + FatalLevel: "FATAL", +} + +// 级别简写 +var levelShort = map[Level]string{ + DebugLevel: "[d]", + InfoLevel: "[i]", + WarnLevel: "[w]", + ErrorLevel: "[e]", + FatalLevel: "[f]", +} + +// 级别颜色 +var levelColor = map[Level]func(format string, a ...interface{}) string{ + DebugLevel: color.CyanString, // 青色 + InfoLevel: color.GreenString, // 绿色 + WarnLevel: color.YellowString, // 黄色 + ErrorLevel: color.RedString, // 红色 + FatalLevel: color.MagentaString, // 品红 +} + +// Logger 日志器结构体 +type Logger struct { + mu sync.Mutex + consoleOut io.Writer // 控制台输出 + fileOut io.Writer // 文件输出 + minLevel Level // 最小输出级别 + showColor bool // 是否显示颜色 + showCaller bool // 是否显示调用者信息 + callerSkip int // 调用者跳过的层级 + timeFormat string // 时间格式 +} + +// 默认日志器实例 +var defaultLogger *Logger + +const ( + defaultTimeFormat = "2006-01-02 15:04:05" + logFile = "/var/log/sunhpc/sunhpc.log" +) + +// Init 初始化日志系统 +func Init(verbose bool) { + // 确保日志目录存在 + if err := os.MkdirAll(filepath.Dir(logFile), 0755); err != nil { + fmt.Fprintf(os.Stderr, "创建日志目录失败: %v\n", err) + os.Exit(1) + } + + // 打开日志文件 + file, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + fmt.Fprintf(os.Stderr, "打开日志文件失败: %v\n", err) + os.Exit(1) + } + + // 控制台输出 + consoleOut := os.Stdout + + // 创建日志器 + defaultLogger = &Logger{ + consoleOut: consoleOut, + fileOut: file, + minLevel: InfoLevel, + showColor: true, + showCaller: false, + callerSkip: 2, + timeFormat: defaultTimeFormat, + } + + // 详细模式下显示调试信息 + if verbose { + defaultLogger.minLevel = DebugLevel + defaultLogger.showCaller = true + } + + // 初始化颜色支持 + if runtime.GOOS == "windows" { + color.NoColor = false + } +} + +// log 核心日志输出方法 +func (l *Logger) log(level Level, format string, args ...interface{}) { + if level < l.minLevel { + return + } + + l.mu.Lock() + defer l.mu.Unlock() + + // 生成时间戳 + timestamp := time.Now().Format(l.timeFormat) + + // 获取调用者信息 + caller := "" + if l.showCaller { + _, file, line, ok := runtime.Caller(l.callerSkip) + if ok { + // 只保留文件名和行号 + file = filepath.Base(file) + caller = fmt.Sprintf(" %s:%d", file, line) + } + } + + // 格式化消息 + var message string + if format == "" { + message = fmt.Sprint(args...) + } else { + message = fmt.Sprintf(format, args...) + } + + // ---- 控制台输出(带颜色和简写)---- + if l.consoleOut != nil { + // 获取级别简写 + shortPrefix := levelShort[level] + + // 构建控制台行 + var consoleLine string + + if l.showColor { + // 带颜色输出 - 简写有颜色,时间戳灰色 + colorFunc := levelColor[level] + consoleLine = fmt.Sprintf("%s %s %s", + color.HiBlackString(timestamp), // 时间戳灰色 + colorFunc(shortPrefix), // 级别简写彩色 + message) // 消息普通颜色 + } else { + // 不带颜色输出 + consoleLine = fmt.Sprintf("%s %s %s", + timestamp, + shortPrefix, + message) + } + + // 添加调用者信息(灰色) + if caller != "" { + if l.showColor { + consoleLine += fmt.Sprintf(" %s", color.HiBlackString(caller)) + } else { + consoleLine += fmt.Sprintf(" %s", caller) + } + } + + fmt.Fprintln(l.consoleOut, consoleLine) + } + + // ---- 文件输出(完整格式)---- + if l.fileOut != nil { + // 获取级别全名 + levelName := levelNames[level] + + // 文件使用完整格式:时间 [级别] 消息 调用者 + fileLine := fmt.Sprintf("%s [%s] %s%s\n", + timestamp, + levelName, + message, + caller) + + fmt.Fprint(l.fileOut, fileLine) + } + + // 致命错误退出程序 + if level == FatalLevel { + os.Exit(1) + } +} + +// 全局日志函数 + +// Debug 调试日志 +func Debug(args ...interface{}) { + if defaultLogger != nil { + defaultLogger.log(DebugLevel, "", args...) + } +} + +// Debugf 格式化调试日志 +func Debugf(format string, args ...interface{}) { + if defaultLogger != nil { + defaultLogger.log(DebugLevel, format, args...) + } +} + +// Info 信息日志 +func Info(args ...interface{}) { + if defaultLogger != nil { + defaultLogger.log(InfoLevel, "", args...) + } +} + +// Infof 格式化信息日志 +func Infof(format string, args ...interface{}) { + if defaultLogger != nil { + defaultLogger.log(InfoLevel, format, args...) + } +} + +// Warn 警告日志 +func Warn(args ...interface{}) { + if defaultLogger != nil { + defaultLogger.log(WarnLevel, "", args...) + } +} + +// Warnf 格式化警告日志 +func Warnf(format string, args ...interface{}) { + if defaultLogger != nil { + defaultLogger.log(WarnLevel, format, args...) + } +} + +// Error 错误日志 +func Error(args ...interface{}) { + if defaultLogger != nil { + defaultLogger.log(ErrorLevel, "", args...) + } +} + +// Errorf 格式化错误日志 +func Errorf(format string, args ...interface{}) { + if defaultLogger != nil { + defaultLogger.log(ErrorLevel, format, args...) + } +} + +// Fatal 致命错误日志,输出后退出程序 +func Fatal(args ...interface{}) { + if defaultLogger != nil { + defaultLogger.log(FatalLevel, "", args...) + } +} + +// Fatalf 格式化致命错误日志,输出后退出程序 +func Fatalf(format string, args ...interface{}) { + if defaultLogger != nil { + defaultLogger.log(FatalLevel, format, args...) + } +} + +// Writer 返回一个 io.Writer,可将子命令的输出写入日志(Debug级别) +func Writer() *io.PipeWriter { + r, w := io.Pipe() + go func() { + buf := make([]byte, 1024) + for { + n, err := r.Read(buf) + if n > 0 { + Debug(string(buf[:n])) + } + if err != nil { + break + } + } + }() + return w +} + +// SetLevel 设置日志级别 +func SetLevel(level Level) { + if defaultLogger != nil { + defaultLogger.mu.Lock() + defer defaultLogger.mu.Unlock() + defaultLogger.minLevel = level + } +} + +// EnableColor 启用/禁用颜色输出 +func EnableColor(enable bool) { + if defaultLogger != nil { + defaultLogger.mu.Lock() + defer defaultLogger.mu.Unlock() + defaultLogger.showColor = enable + } +} + +// EnableCaller 启用/禁用调用者信息 +func EnableCaller(enable bool) { + if defaultLogger != nil { + defaultLogger.mu.Lock() + defer defaultLogger.mu.Unlock() + defaultLogger.showCaller = enable + } +} + +// SetTimeFormat 设置时间格式 +func SetTimeFormat(format string) { + if defaultLogger != nil { + defaultLogger.mu.Lock() + defer defaultLogger.mu.Unlock() + defaultLogger.timeFormat = format + } +} + +// Sync 同步日志文件 +func Sync() { + if defaultLogger != nil && defaultLogger.fileOut != nil { + if f, ok := defaultLogger.fileOut.(*os.File); ok { + f.Sync() + } + } +} + +// Close 关闭日志文件 +func Close() error { + if defaultLogger != nil && defaultLogger.fileOut != nil { + if f, ok := defaultLogger.fileOut.(*os.File); ok { + return f.Close() + } + } + return nil +} diff --git a/internal/service/manager.go b/internal/service/manager.go new file mode 100644 index 0000000..9e6f8c6 --- /dev/null +++ b/internal/service/manager.go @@ -0,0 +1,25 @@ +package service + +import ( + "fmt" + "sunhpc/internal/config" + "sunhpc/internal/log" + "sunhpc/internal/template" +) + +func Deploy(cfg *config.ServicesConfig) error { + // 示例:使用模板部署 DHCPD + if cfg.DHCPD.Enabled { + log.Info("部署 DHCPD 服务...") + // 从模板渲染配置文件 + err := template.RenderAndExecute("dhcpd.conf.tmpl", map[string]interface{}{ + "Subnet": "192.168.1.0", + "Netmask": "255.255.255.0", + }) + if err != nil { + return fmt.Errorf("DHCPD 配置失败: %v", err) + } + // 实际部署逻辑(启动服务等)... + } + return nil +} diff --git a/internal/soft/binary.go b/internal/soft/binary.go new file mode 100644 index 0000000..ca19015 --- /dev/null +++ b/internal/soft/binary.go @@ -0,0 +1,33 @@ +package soft + +import ( + "fmt" + "os" + "os/exec" + "strings" +) + +// extractBinary 解压二进制压缩包到目标目录 +func extractBinary(binPath, destDir string) error { + // 确保目标目录存在 + if err := os.MkdirAll(destDir, 0755); err != nil { + return err + } + + // 根据扩展名选择解压命令 + var cmd *exec.Cmd + switch { + case strings.HasSuffix(binPath, ".tar.gz"), strings.HasSuffix(binPath, ".tgz"): + cmd = exec.Command("tar", "xzf", binPath, "-C", destDir) + case strings.HasSuffix(binPath, ".tar.bz2"): + cmd = exec.Command("tar", "xjf", binPath, "-C", destDir) + case strings.HasSuffix(binPath, ".zip"): + cmd = exec.Command("unzip", binPath, "-d", destDir) + default: + return fmt.Errorf("不支持的压缩格式: %s", binPath) + } + + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + return cmd.Run() +} diff --git a/internal/soft/manager.go b/internal/soft/manager.go new file mode 100644 index 0000000..1e0ba21 --- /dev/null +++ b/internal/soft/manager.go @@ -0,0 +1,49 @@ +package soft + +import ( + "sunhpc/internal/log" +) + +// InstallContext 安装上下文,包含所有命令行参数 +type InstallContext struct { + Force bool // 强制安装 + DryRun bool // 干运行模式 + KeepSource bool // 保留源码文件 + Jobs int // 编译线程数 + Offline bool // 离线模式 +} + +// InstallFromSource 从源码编译安装 +func InstallFromSource(name, srcPath, prefix, version string, ctx *InstallContext) error { + log.Infof("正在从源码安装 %s,路径: %s", name, srcPath) + if ctx != nil && ctx.DryRun { + log.Infof("[干运行] 将执行: configure && make -j%d && make install", ctx.Jobs) + return nil + } + // TODO: 实现具体逻辑:下载、解压、./configure、make、make install + log.Info("源码安装模拟完成(需实现具体步骤)") + return nil +} + +// InstallFromBinary 从二进制压缩包安装 +func InstallFromBinary(name, binPath, prefix string, ctx *InstallContext) error { + log.Infof("正在安装二进制包 %s,路径: %s", name, binPath) + if ctx != nil && ctx.DryRun { + log.Infof("[干运行] 将解压 %s 到 %s", binPath, prefix) + return nil + } + // TODO: 解压到 prefix + log.Info("二进制安装模拟完成(需实现具体步骤)") + return nil +} + +// InstallFromPackage 通过系统包管理器安装 +func InstallFromPackage(name, pkgType string, ctx *InstallContext) error { + log.Infof("正在通过包管理器安装 %s (%s)", name, pkgType) + if ctx != nil && ctx.DryRun { + log.Infof("[干运行] 将执行包管理器安装 %s", name) + return nil + } + // 具体实现在下面的 package.go 中 + return installViaPackageManager(name, pkgType) +} diff --git a/internal/soft/package.go b/internal/soft/package.go new file mode 100644 index 0000000..ed3aa2f --- /dev/null +++ b/internal/soft/package.go @@ -0,0 +1,38 @@ +package soft + +import ( + "fmt" + "os" + "os/exec" + "sunhpc/internal/log" + "sunhpc/pkg/utils" +) + +// installViaPackageManager 使用系统包管理器安装软件 +func installViaPackageManager(name, pkgType string) error { + var cmd *exec.Cmd + switch pkgType { + case "rpm": + // RHEL/CentOS + if utils.CommandExists("yum") { + cmd = exec.Command("yum", "install", "-y", name) + } else if utils.CommandExists("dnf") { + cmd = exec.Command("dnf", "install", "-y", name) + } else { + return fmt.Errorf("未找到 yum 或 dnf 包管理器") + } + case "deb": + // Debian/Ubuntu + if !utils.CommandExists("apt-get") { + return fmt.Errorf("未找到 apt-get 包管理器") + } + cmd = exec.Command("apt-get", "install", "-y", name) + default: + return fmt.Errorf("不支持的包类型: %s", pkgType) + } + + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + log.Infof("执行命令: %s", cmd.String()) + return cmd.Run() +} diff --git a/internal/soft/source.go b/internal/soft/source.go new file mode 100644 index 0000000..fc381db --- /dev/null +++ b/internal/soft/source.go @@ -0,0 +1,48 @@ +package soft + +import ( + "fmt" + "os" + "os/exec" + "sunhpc/internal/log" + "sunhpc/pkg/utils" +) + +// compileFromSource 通用源码编译流程 +func compileFromSource(srcDir, prefix string, jobs int) error { + // 切换到源码目录 + if err := os.Chdir(srcDir); err != nil { + return fmt.Errorf("进入源码目录失败: %v", err) + } + + // 检测 configure 脚本是否存在 + if utils.FileExists("./configure") { + log.Debug("执行 configure ...") + cmd := exec.Command("./configure", "--prefix="+prefix) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("configure 失败: %v", err) + } + } + + // make + log.Debugf("执行 make -j%d ...", jobs) + cmd := exec.Command("make", fmt.Sprintf("-j%d", jobs)) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("make 失败: %v", err) + } + + // make install + log.Debug("执行 make install ...") + cmd = exec.Command("make", "install") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("make install 失败: %v", err) + } + + return nil +} diff --git a/internal/system/hostname.go b/internal/system/hostname.go new file mode 100644 index 0000000..bb6f8ef --- /dev/null +++ b/internal/system/hostname.go @@ -0,0 +1,97 @@ +package system + +import ( + "fmt" + "os" + "os/exec" + "strings" +) + +// SetHostname 设置系统主机名 +// 参数: hostname - 目标主机名 +// 返回: error - 如果设置失败返回错误信息 +func SetHostname(hostname string) error { + if hostname == "" { + return nil // 空值跳过,不报错 + } + + // 检查是否已有相同主机名 + current, err := os.Hostname() + if err == nil && current == hostname { + return nil // 已经设置正确,无需修改 + } + + // 使用 hostnamectl 设置主机名(适用于 systemd 系统) + if _, err := exec.LookPath("hostnamectl"); err == nil { + cmd := exec.Command("hostnamectl", "set-hostname", hostname) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("hostnamectl 设置主机名失败: %v", err) + } + } else { + // 传统方法:直接修改 /etc/hostname + if err := os.WriteFile("/etc/hostname", []byte(hostname+"\n"), 0644); err != nil { + return fmt.Errorf("写入 /etc/hostname 失败: %v", err) + } + + // 立即生效(需要内核支持) + cmd := exec.Command("sysctl", "kernel.hostname="+hostname) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + // 不返回错误,因为重启后会生效 + fmt.Printf("警告: 主机名将在重启后完全生效\n") + } + } + + // 更新 /etc/hosts,确保本机解析正确 + if err := updateHostsFile(hostname); err != nil { + fmt.Printf("警告: 更新 /etc/hosts 失败: %v\n", err) + } + + return nil +} + +// updateHostsFile 更新 /etc/hosts 文件中的本机映射 +func updateHostsFile(hostname string) error { + content, err := os.ReadFile("/etc/hosts") + if err != nil { + return err + } + + lines := strings.Split(string(content), "\n") + newLines := make([]string, 0, len(lines)) + hostnameSet := false + + for _, line := range lines { + // 跳过空行和注释 + if line == "" || strings.HasPrefix(line, "#") { + newLines = append(newLines, line) + continue + } + + fields := strings.Fields(line) + if len(fields) >= 2 && fields[0] == "127.0.1.1" { + // 替换 Ubuntu/Debian 风格的本地主机名 + newLines = append(newLines, "127.0.1.1\t"+hostname) + hostnameSet = true + } else if len(fields) >= 2 && fields[0] == "127.0.0.1" { + // 保留原行,但检查是否包含主机名 + if !strings.Contains(line, hostname) { + line = line + " " + hostname + } + newLines = append(newLines, line) + hostnameSet = true + } else { + newLines = append(newLines, line) + } + } + + // 如果没有找到合适的位置,添加一行 + if !hostnameSet { + newLines = append(newLines, "127.0.1.1\t"+hostname) + } + + return os.WriteFile("/etc/hosts", []byte(strings.Join(newLines, "\n")), 0644) +} diff --git a/internal/system/motd.go b/internal/system/motd.go new file mode 100644 index 0000000..2b4c404 --- /dev/null +++ b/internal/system/motd.go @@ -0,0 +1,47 @@ +package system + +import ( + "os" + "time" +) + +// SetMOTD 设置 /etc/motd 文件内容 +// 参数: content - MOTD 文本内容 +// 返回: error - 写入文件错误 +func SetMOTD(content string) error { + if content == "" { + // 如果内容为空,不清除现有 MOTD,避免误操作 + return nil + } + + // 添加时间和系统信息 + finalContent := "========================================\n" + finalContent += "SunHPC 集群管理系统\n" + finalContent += "时间: " + time.Now().Format("2006-01-02 15:04:05") + "\n" + finalContent += "========================================\n\n" + finalContent += content + + // 确保行尾有换行 + if content[len(content)-1] != '\n' { + finalContent += "\n" + } + + return os.WriteFile("/etc/motd", []byte(finalContent), 0644) +} + +// ClearMOTD 清空 MOTD +func ClearMOTD() error { + return os.WriteFile("/etc/motd", []byte{}, 0644) +} + +// AppendToMOTD 追加内容到 MOTD +func AppendToMOTD(additional string) error { + f, err := os.OpenFile("/etc/motd", os.O_APPEND|os.O_WRONLY, 0644) + if err != nil { + return err + } + defer f.Close() + + _, err = f.WriteString(additional + "\n") + return err +} diff --git a/internal/system/selinux.go b/internal/system/selinux.go new file mode 100644 index 0000000..80254bb --- /dev/null +++ b/internal/system/selinux.go @@ -0,0 +1,97 @@ +package system + +import ( + "fmt" + "os" + "os/exec" + "strings" +) + +// ConfigureSELinux 设置 SELinux 模式 +// 参数: mode - enforcing, permissive, disabled +// 返回: error - 配置错误 +func ConfigureSELinux(mode string) error { + if mode == "" { + return nil + } + + // 验证输入 + mode = strings.ToLower(strings.TrimSpace(mode)) + validModes := map[string]bool{ + "enforcing": true, + "permissive": true, + "disabled": true, + } + + if !validModes[mode] { + return fmt.Errorf("无效的 SELinux 模式: %s (可选: enforcing, permissive, disabled)", mode) + } + + // 检查 SELinux 是否可用 + if _, err := os.Stat("/selinux/enforce"); os.IsNotExist(err) { + if _, err := os.Stat("/sys/fs/selinux/enforce"); os.IsNotExist(err) { + return fmt.Errorf("系统不支持 SELinux 或未启用") + } + } + + // 临时生效 + if mode != "disabled" { // disabled 需要重启才能完全生效 + cmd := exec.Command("setenforce", mode) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("setenforce %s 失败: %v", mode, err) + } + } + + // 持久化配置 + return persistSELinuxMode(mode) +} + +// persistSELinuxMode 修改 /etc/selinux/config 实现持久化 +func persistSELinuxMode(mode string) error { + const selinuxConfig = "/etc/selinux/config" + + // 读取配置文件 + content, err := os.ReadFile(selinuxConfig) + if err != nil { + // 如果文件不存在,创建默认配置 + if os.IsNotExist(err) { + defaultConfig := fmt.Sprintf(`# This file controls the state of SELinux on the system. +# SELINUX= can take one of these three values: +# enforcing - SELinux security policy is enforced. +# permissive - SELinux prints warnings instead of enforcing. +# disabled - No SELinux policy is loaded. +SELINUX=%s +# SELINUXTYPE= can take one of three two values: +# targeted - Targeted processes are protected, +# minimum - Modification of targeted policy. Only selected processes are protected. +# mls - Multi Level Security protection. +SELINUXTYPE=targeted +`, mode) + return os.WriteFile(selinuxConfig, []byte(defaultConfig), 0644) + } + return err + } + + // 替换 SELINUX= 行 + lines := strings.Split(string(content), "\n") + for i, line := range lines { + if strings.HasPrefix(line, "SELINUX=") { + lines[i] = fmt.Sprintf("SELINUX=%s", mode) + break + } + } + + return os.WriteFile(selinuxConfig, []byte(strings.Join(lines, "\n")), 0644) +} + +// GetSELinuxMode 获取当前 SELinux 模式 +func GetSELinuxMode() (string, error) { + cmd := exec.Command("getenforce") + output, err := cmd.Output() + if err != nil { + return "", err + } + return strings.ToLower(strings.TrimSpace(string(output))), nil +} diff --git a/internal/system/ssh.go b/internal/system/ssh.go new file mode 100644 index 0000000..4984b98 --- /dev/null +++ b/internal/system/ssh.go @@ -0,0 +1,187 @@ +package system + +import ( + "fmt" + "os" + "os/exec" + "strings" + + "sunhpc/internal/config" +) + +// ConfigureSSH 配置 SSH 服务 +// 参数: cfg - config.SSHConfig 结构体 +// 返回: error - 配置错误 +func ConfigureSSH(cfg config.SSHConfig) error { + const sshdConfig = "/etc/ssh/sshd_config" + + // 读取现有配置 + content, err := os.ReadFile(sshdConfig) + if err != nil { + return fmt.Errorf("读取 sshd_config 失败: %v", err) + } + + // 备份原始配置 + backupPath := sshdConfig + ".sunhpc.bak" + if _, err := os.Stat(backupPath); os.IsNotExist(err) { + if err := os.WriteFile(backupPath, content, 0644); err != nil { + fmt.Printf("警告: 无法创建备份文件 %s: %v\n", backupPath, err) + } + } + + // 解析和修改配置 + lines := strings.Split(string(content), "\n") + newLines := make([]string, 0, len(lines)) + configMap := make(map[string]bool) + + // 处理每一行 + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + newLines = append(newLines, line) + continue + } + + parts := strings.Fields(trimmed) + if len(parts) < 2 { + newLines = append(newLines, line) + continue + } + + key := parts[0] + configMap[key] = true + + // 根据配置更新 + switch key { + case "PermitRootLogin": + if cfg.PermitRootLogin != "" { + newLines = append(newLines, fmt.Sprintf("PermitRootLogin %s", cfg.PermitRootLogin)) + } else { + newLines = append(newLines, line) + } + case "PasswordAuthentication": + if cfg.PasswordAuth != "" { + newLines = append(newLines, fmt.Sprintf("PasswordAuthentication %s", cfg.PasswordAuth)) + } else { + newLines = append(newLines, line) + } + default: + newLines = append(newLines, line) + } + } + + // 添加缺失的配置项 + if cfg.PermitRootLogin != "" && !configMap["PermitRootLogin"] { + newLines = append(newLines, fmt.Sprintf("PermitRootLogin %s", cfg.PermitRootLogin)) + } + if cfg.PasswordAuth != "" && !configMap["PasswordAuthentication"] { + newLines = append(newLines, fmt.Sprintf("PasswordAuthentication %s", cfg.PasswordAuth)) + } + + // 写入新配置 + newContent := strings.Join(newLines, "\n") + if err := os.WriteFile(sshdConfig, []byte(newContent), 0644); err != nil { + return fmt.Errorf("写入 sshd_config 失败: %v", err) + } + + // 测试配置语法 + if err := testSSHDConfig(); err != nil { + // 恢复备份 + if backup, err := os.ReadFile(backupPath); err == nil { + os.WriteFile(sshdConfig, backup, 0644) + } + return fmt.Errorf("SSH 配置语法错误: %v", err) + } + + // 重启 SSH 服务 + return restartSSHD() +} + +// testSSHDConfig 测试 sshd 配置语法 +func testSSHDConfig() error { + cmd := exec.Command("sshd", "-t") + cmd.Stderr = os.Stderr + return cmd.Run() +} + +// restartSSHD 重启 SSH 服务 +func restartSSHD() error { + // 尝试不同的服务管理器 + serviceMgrs := []struct { + name string + args []string + }{ + {"systemctl", []string{"restart", "sshd"}}, + {"systemctl", []string{"restart", "ssh"}}, + {"service", []string{"sshd", "restart"}}, + {"service", []string{"ssh", "restart"}}, + } + + for _, mgr := range serviceMgrs { + if _, err := exec.LookPath(mgr.name); err == nil { + cmd := exec.Command(mgr.name, mgr.args...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err == nil { + return nil + } + } + } + + return fmt.Errorf("无法重启 SSH 服务,请手动重启") +} + +// AddSSHKey 添加 SSH 公钥到指定用户 +func AddSSHKey(username, pubkey string) error { + // 获取用户主目录 + homeDir, err := getUserHomeDir(username) + if err != nil { + return err + } + + sshDir := homeDir + "/.ssh" + authKeys := sshDir + "/authorized_keys" + + // 创建 .ssh 目录 + if err := os.MkdirAll(sshDir, 0700); err != nil { + return err + } + + // 追加公钥 + f, err := os.OpenFile(authKeys, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600) + if err != nil { + return err + } + defer f.Close() + + _, err = f.WriteString(pubkey + "\n") + if err != nil { + return err + } + + // 修改所有权 + return chownRecursive(sshDir, username) +} + +// getUserHomeDir 获取用户主目录 +func getUserHomeDir(username string) (string, error) { + cmd := exec.Command("getent", "passwd", username) + output, err := cmd.Output() + if err != nil { + return "", fmt.Errorf("用户 %s 不存在", username) + } + + parts := strings.Split(strings.TrimSpace(string(output)), ":") + if len(parts) >= 6 { + return parts[5], nil + } + return "", fmt.Errorf("无法获取用户 %s 的主目录", username) +} + +// chownRecursive 递归修改文件所有者 +func chownRecursive(path, username string) error { + cmd := exec.Command("chown", "-R", username+":"+username, path) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + return cmd.Run() +} diff --git a/internal/system/sysctl.go b/internal/system/sysctl.go new file mode 100644 index 0000000..5986e57 --- /dev/null +++ b/internal/system/sysctl.go @@ -0,0 +1,87 @@ +package system + +import ( + "bufio" + "fmt" + "os" + "os/exec" + "strings" + "time" +) + +// ConfigureSysctl 设置内核参数 +// 参数: params - 键值对映射,如 {"net.ipv4.ip_forward": "1"} +// 返回: error - 第一个失败的错误 +func ConfigureSysctl(params map[string]string) error { + if len(params) == 0 { + return nil + } + + // 首先应用临时配置 + for k, v := range params { + cmd := exec.Command("sysctl", "-w", fmt.Sprintf("%s=%s", k, v)) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("设置 sysctl %s=%s 失败: %v", k, v, err) + } + } + + // 持久化配置到 /etc/sysctl.conf 或 /etc/sysctl.d/ + return appendToSysctlConf(params) +} + +// appendToSysctlConf 将参数写入持久化配置文件 +func appendToSysctlConf(params map[string]string) error { + const sysctlFile = "/etc/sysctl.d/99-sunhpc.conf" + + // 读取现有配置 + existing := make(map[string]bool) + if data, err := os.ReadFile(sysctlFile); err == nil { + scanner := bufio.NewScanner(strings.NewReader(string(data))) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + parts := strings.SplitN(line, "=", 2) + if len(parts) == 2 { + existing[strings.TrimSpace(parts[0])] = true + } + } + } + + // 构建新内容 + var content strings.Builder + content.WriteString("# SunHPC 系统优化配置\n") + content.WriteString("# 生成时间: " + time.Now().Format(time.RFC3339) + "\n\n") + + for k, v := range params { + // 跳过已存在的配置(避免重复) + if existing[k] { + continue + } + content.WriteString(fmt.Sprintf("%s = %s\n", k, v)) + } + + // 如果有新配置才写入 + if content.Len() > 100 { + if err := os.WriteFile(sysctlFile, []byte(content.String()), 0644); err != nil { + return err + } + // 应用持久化配置 + return exec.Command("sysctl", "--system").Run() + } + + return nil +} + +// GetSysctl 获取当前内核参数值 +func GetSysctl(key string) (string, error) { + cmd := exec.Command("sysctl", "-n", key) + output, err := cmd.Output() + if err != nil { + return "", err + } + return strings.TrimSpace(string(output)), nil +} diff --git a/internal/system/system.go b/internal/system/system.go new file mode 100644 index 0000000..6164d3d --- /dev/null +++ b/internal/system/system.go @@ -0,0 +1,178 @@ +package system + +import ( + "os" + "path/filepath" + "sunhpc/internal/config" + "sunhpc/internal/log" +) + +// Context 系统配置上下文,包含所有命令行参数 +type Context struct { + Force bool // 强制模式 + DryRun bool // 干运行模式 + Verbose bool // 详细输出 + Timeout int // 超时时间 + Backup string // 备份路径 + YesMode bool // 自动确认 +} + +// ApplyAll 应用所有系统配置 +func ApplyAll(cfg *config.SunHPCConfig) error { + log.Info("开始应用系统配置...") + + if err := SetHostnameWithContext(cfg.Hostname, nil); err != nil { + log.Warnf("设置主机名失败: %v", err) + } + + if err := SetMOTDWithContext(cfg.MOTD, nil); err != nil { + log.Warnf("设置 MOTD 失败: %v", err) + } + + if err := ConfigureSysctlWithContext(cfg.Sysctl, nil); err != nil { + log.Warnf("配置 sysctl 失败: %v", err) + } + + if err := ConfigureSELinuxWithContext(cfg.SELinux, nil); err != nil { + log.Warnf("配置 SELinux 失败: %v", err) + } + + if err := ConfigureSSHWithContext(cfg.SSH, nil); err != nil { + log.Warnf("配置 SSH 失败: %v", err) + } + + log.Info("系统配置应用完成") + return nil +} + +// SetHostnameWithContext 设置系统主机名,带上下文参数 +func SetHostnameWithContext(hostname string, ctx *Context) error { + if ctx != nil && ctx.DryRun { + log.Infof("[干运行] 设置主机名为: %s", hostname) + return nil + } + + if hostname == "" { + return nil + } + + // 检查是否需要强制设置 + current, _ := os.Hostname() + if current == hostname && (ctx == nil || !ctx.Force) { + log.Infof("主机名已是 '%s',跳过设置", hostname) + return nil + } + + log.Infof("设置主机名为: %s", hostname) + return SetHostname(hostname) +} + +// SetMOTDWithContext 设置 MOTD,带上下文参数 +func SetMOTDWithContext(content string, ctx *Context) error { + if ctx != nil && ctx.DryRun { + log.Info("[干运行] 设置 MOTD") + return nil + } + + if content == "" { + return nil + } + + // 备份现有文件 + if ctx != nil && ctx.Backup != "" { + backupMOTD(ctx.Backup) + } + + log.Info("更新 /etc/motd") + return SetMOTD(content) +} + +// ConfigureSysctlWithContext 配置内核参数,带上下文参数 +func ConfigureSysctlWithContext(params map[string]string, ctx *Context) error { + if ctx != nil && ctx.DryRun { + log.Info("[干运行] 配置 sysctl 参数") + return nil + } + + if len(params) == 0 { + return nil + } + + // 备份现有配置 + if ctx != nil && ctx.Backup != "" { + backupSysctl(ctx.Backup) + } + + return ConfigureSysctl(params) +} + +// ConfigureSELinuxWithContext 配置 SELinux,带上下文参数 +func ConfigureSELinuxWithContext(mode string, ctx *Context) error { + if ctx != nil && ctx.DryRun { + log.Infof("[干运行] 设置 SELinux 模式为: %s", mode) + return nil + } + + if mode == "" { + return nil + } + + // 检查当前模式 + current, _ := GetSELinuxMode() + if current == mode && (ctx == nil || !ctx.Force) { + log.Infof("SELinux 已是 '%s' 模式,跳过设置", mode) + return nil + } + + log.Infof("设置 SELinux 模式为: %s", mode) + return ConfigureSELinux(mode) +} + +// ConfigureSSHWithContext 配置 SSH,带上下文参数 +func ConfigureSSHWithContext(cfg config.SSHConfig, ctx *Context) error { + if ctx != nil && ctx.DryRun { + log.Info("[干运行] 配置 SSH 服务") + return nil + } + + // 备份配置文件 + if ctx != nil && ctx.Backup != "" { + backupSSHConfig(ctx.Backup) + } + + log.Info("配置 SSH 服务") + return ConfigureSSH(cfg) +} + +// 备份函数 +func backupMOTD(backupDir string) error { + backupPath := filepath.Join(backupDir, "motd."+filepath.Base(os.Args[0])+".bak") + if err := os.MkdirAll(backupDir, 0755); err != nil { + return err + } + return copyFile("/etc/motd", backupPath) +} + +func backupSysctl(backupDir string) error { + backupPath := filepath.Join(backupDir, "sysctl.conf.bak") + if err := os.MkdirAll(backupDir, 0755); err != nil { + return err + } + return copyFile("/etc/sysctl.conf", backupPath) +} + +func backupSSHConfig(backupDir string) error { + backupPath := filepath.Join(backupDir, "sshd_config.bak") + if err := os.MkdirAll(backupDir, 0755); err != nil { + return err + } + return copyFile("/etc/ssh/sshd_config", backupPath) +} + +func copyFile(src, dst string) error { + data, err := os.ReadFile(src) + if err != nil { + return err + } + return os.WriteFile(dst, data, 0644) +} diff --git a/internal/template/engine.go b/internal/template/engine.go new file mode 100644 index 0000000..ca367e0 --- /dev/null +++ b/internal/template/engine.go @@ -0,0 +1,61 @@ +package template + +import ( + "bytes" + "fmt" + "os" + "os/exec" + "path/filepath" + "text/template" + + "sunhpc/internal/config" + "sunhpc/internal/log" +) + +// RenderAndExecute 从模板目录加载模板,渲染后生成临时脚本并执行 +// tmplName: 模板文件名(位于 /etc/sunhpc/tmpl.d/) +// data: 模板变量 +func RenderAndExecute(tmplName string, data interface{}) error { + tmplPath := filepath.Join(config.TmplDir, tmplName) + if _, err := os.Stat(tmplPath); err != nil { + return fmt.Errorf("模板文件不存在: %s", tmplPath) + } + + content, err := os.ReadFile(tmplPath) + if err != nil { + return err + } + + t, err := template.New(tmplName).Parse(string(content)) + if err != nil { + return err + } + + var buf bytes.Buffer + if err := t.Execute(&buf, data); err != nil { + return err + } + + // 生成临时脚本 + tmpFile, err := os.CreateTemp("/tmp", "sunhpc-*.sh") + if err != nil { + return err + } + defer os.Remove(tmpFile.Name()) + + if _, err := tmpFile.Write(buf.Bytes()); err != nil { + tmpFile.Close() + return err + } + tmpFile.Close() + + if err := os.Chmod(tmpFile.Name(), 0755); err != nil { + return err + } + + log.Infof("执行模板脚本: %s", tmpFile.Name()) + cmd := exec.Command("/bin/bash", tmpFile.Name()) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + return cmd.Run() +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..38d95a2 --- /dev/null +++ b/main.go @@ -0,0 +1,46 @@ +package main + +import ( + "os" + "sunhpc/cmd" +) + +func main() { + if err := cmd.Execute(); err != nil { + os.Exit(1) + } +} + +/* + // 初始化日志(verbose=true 显示调试信息) + log.Init(true) + + // 基础用法 + log.Info("服务启动成功") + log.Infof("用户 %s 已添加", "testuser") + + log.Warn("磁盘使用率超过 80%") + log.Warnf("节点 %s 网络延迟过高", "node01") + + log.Error("配置文件解析失败") + log.Errorf("无法连接到数据库: %v", err) + + log.Debug("正在执行命令: ssh root@192.168.1.1") + log.Debugf("加载了 %d 个节点配置", len(nodes)) + + // 致命错误 + if err != nil { + log.Fatal("初始化失败: ", err) + } + + // 临时禁用颜色(例如输出重定向时) + if !isTerminal { + log.EnableColor(false) + } + + // 设置日志级别 + log.SetLevel(log.WarnLevel) // 只显示警告及以上级别 + + // 启用调用者信息 + log.EnableCaller(true) +*/ diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go new file mode 100644 index 0000000..cd8119c --- /dev/null +++ b/pkg/utils/utils.go @@ -0,0 +1,18 @@ +package utils + +import ( + "os" + "os/exec" +) + +// CommandExists 检查命令是否存在 +func CommandExists(cmd string) bool { + _, err := exec.LookPath(cmd) + return err == nil +} + +// FileExists 检查文件是否存在 +func FileExists(path string) bool { + _, err := os.Stat(path) + return err == nil || !os.IsNotExist(err) +} diff --git a/sunhpc b/sunhpc new file mode 100755 index 0000000..bbb83fd Binary files /dev/null and b/sunhpc differ diff --git a/test_db.sh b/test_db.sh new file mode 100644 index 0000000..651bf5c --- /dev/null +++ b/test_db.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +echo "=========================================" +echo "SunHPC 数据库测试脚本" +echo "=========================================" +echo "" + +# 1. 清理旧数据库 +echo "[1/4] 清理旧数据库..." +rm -f /var/lib/sunhpc/sunhpc.db* +rm -rf /var/lib/sunhpc/ +echo "✓ 清理完成" +echo "" + +# 2. 初始化数据库 +echo "[2/4] 初始化数据库..." +./sunhpc init database -v --force +if [ $? -ne 0 ]; then + echo "✗ 初始化失败" + exit 1 +fi +echo "✓ 初始化完成" +echo "" + +# 3. 添加节点 +echo "[3/4] 添加节点..." +./sunhpc node add node1 --cpus 32 --memory 128 --disk 1000 --os "CentOS 7.9" --kernel "3.10.0-1160.el7.x86_64" +if [ $? -ne 0 ]; then + echo "✗ 添加节点失败" + exit 1 +fi +echo "✓ 节点添加完成" +echo "" + +# 4. 查询节点 +echo "[4/4] 查询节点列表..." +./sunhpc node list +if [ $? -ne 0 ]; then + echo "✗ 查询失败" + exit 1 +fi +echo "" + +echo "=========================================" +echo "测试完成!" +echo "========================================="