diff --git a/autofs.yaml b/autofs.yaml new file mode 100644 index 0000000..4ab87e0 --- /dev/null +++ b/autofs.yaml @@ -0,0 +1,29 @@ +description: AutoFS server for SunHPC clusters +copyright: | + Copyright (c) 2026 SunHPC Project. + Licensed under Apache 2.0. + +stages: + post: + - type: file + path: /etc/auto.master + content: | + /share /etc/auto.share --timeout=1200 + /home /etc/auto.home --timeout=1200 + + - type: file + path: /etc/auto.share + content: | + apps {{ .Node.Hostname }}.{{ .Cluster.Domain }}:/export/& + + - type: script + content: | + mkdir -p /export/apps + echo "AutoFS 配置已生成" + + configure: + - type: script + condition: "{{ if .Node.OldHostname }}true{{ end }}" + content: | + sed -i 's/{{ .Node.OldHostname }}/{{ .Node.Hostname }}/g' /etc/auto.share + systemctl restart autofs \ No newline at end of file diff --git a/cmd/init/config.go b/cmd/init/config.go index 1e4a24a..c4d3bc3 100644 --- a/cmd/init/config.go +++ b/cmd/init/config.go @@ -11,31 +11,58 @@ import ( "github.com/spf13/cobra" ) -var configCmd = &cobra.Command{ - Use: "config", - Short: "生成基础配置文件", - Long: "创建 /etc/sunhpc 目录并生成所有默认配置文件(若目录已存在则跳过)", - RunE: func(cmd *cobra.Command, args []string) error { - if err := auth.RequireRoot(); err != nil { - return err - } +// NewConfigCmd 创建 "init config" 命令 +func NewConfigCmd() *cobra.Command { + var ( + force bool + path string + verbose bool + ) - // 检查目录是否已存在 - if _, err := os.Stat(config.BaseDir); err == nil { - log.Warnf("配置目录 %s 已存在,跳过初始化", config.BaseDir) + cmd := &cobra.Command{ + Use: "config", + Short: "生成默认配置文件", + Long: ` + 在指定路径生成 SunHPC 默认配置文件 (sunhpc.yaml) + + 示例: + sunhpc init config # 生成默认配置文件 + sunhpc init config -f # 强制覆盖已有配置文件 + sunhpc init config -p /etc/sunhpc/sunhpc.yaml # 指定路径 + `, + + Annotations: map[string]string{ + "require-root": "true", // 假设需要 root(你可自定义策略) + }, + + RunE: func(cmd *cobra.Command, args []string) error { + if err := auth.RequireRoot(); err != nil { + return err + } + + if path == "" { + path = "/etc/sunhpc/sunhpc.yaml" + } + + if !force { + if _, err := os.Stat(path); err == nil { + return fmt.Errorf("配置文件已存在: %s (使用 --force 覆盖)", path) + } + } + + if err := config.WriteDefaultConfig(path); err != nil { + return fmt.Errorf("写入配置失败: %w", err) + } + + log.Infof("✅ 配置文件已生成: %s", path) return nil - } + }, + } - log.Info("初始化 SunHPC 配置目录...") - if err := config.InitDirs(); err != nil { - return fmt.Errorf("创建目录失败: %v", err) - } + // 定义局部 flags + cmd.Flags().BoolVarP(&force, "force", "f", false, "强制覆盖已有配置文件") + cmd.Flags().StringVarP(&path, "path", "p", "", "指定配置文件路径") + cmd.Flags().BoolVarP(&verbose, "verbose", "v", false, "显示详细日志") - if err := config.CreateDefaultConfigs(); err != nil { - return fmt.Errorf("生成默认配置文件失败: %v", err) - } - - log.Info("配置文件已生成,请根据需要编辑 /etc/sunhpc/ 下的 YAML 文件") - return nil - }, + return cmd } diff --git a/cmd/init/database.go b/cmd/init/database.go index 942c498..b646755 100644 --- a/cmd/init/database.go +++ b/cmd/init/database.go @@ -1,193 +1,47 @@ package initcmd import ( - "bufio" - "fmt" - "os" - "path/filepath" - "strings" - "sunhpc/internal/auth" "sunhpc/internal/db" "sunhpc/internal/log" "github.com/spf13/cobra" - "github.com/spf13/viper" ) -const ( - appName string = "sunhpc" - defaultDBPath string = "/var/lib/sunhpc" - defaultDBName string = "sunhpc.db" -) +func NewDatabaseCmd() *cobra.Command { + var force bool -var ( - forceDB bool - dbPath string - dbName string - dbFullPath string -) - -func initDBPathWithViper() error { - /* - 从 Viper 配置文件获取数据库路径 - 配置文件里的键要和 Viper.GetXXX 的键对应 - 配置文件格式: - db: - path: "/tmp/sunhpc" # 自定义数据库路径 - name: "my_sunhpc.db" # 自定义数据库名 - */ - - log.Infof("从 Viper 配置文件获取数据库路径...") - - // ========== 第一步:设置 Viper 配置文件规则(核心) ========== - // 1. 设置Viper基础规则 - viper.SetConfigType("yaml") // 配置文件类型 - viper.SetConfigName("config") // 配置文件名(不带后缀) - viper.SetEnvPrefix(appName) // 环境变量前缀:SUNHPC_ - viper.AutomaticEnv() // 自动读取环境变量(可选,增强兼容性) - - // 2. 添加配置文件搜索目录(Viper 会按顺序查找,找到第一个就停止) - // 优先级:当前目录 → 用户级目录 → 系统级目录 - - // ① 当前目录(开发/测试常用) - viper.AddConfigPath(".") - - // ② Linux/macOS 用户级目录(~/.config/sunhpc/) - if homeDir, err := os.UserHomeDir(); err == nil { - viper.AddConfigPath(filepath.Join(homeDir, ".config", appName)) - } - // ③ Linux/macOS 系统级目录(/etc/sunhpc/) - viper.AddConfigPath(filepath.Join("/etc", appName)) - - // ========== 第二步:设置默认值(最低优先级) ========== - viper.SetDefault("db.path", defaultDBPath) - viper.SetDefault("db.name", defaultDBName) - - // ========== 第三步:绑定环境变量(优先级高于默认值,低于配置文件) ========== - viper.BindEnv("db.path", "DB_PATH") // 绑定 SUNHPC_DB_PATH → db.path - viper.BindEnv("db.name", "DB_NAME") // 绑定 SUNHPC_DB_NAME → db.name - - // ========== 第四步:读取配置文件(优先级高于环境变量,低于默认值) ========== - if err := viper.ReadInConfig(); err != nil { - if _, ok := err.(viper.ConfigFileNotFoundError); ok { - log.Info("未找到配置文件,将使用环境变量/默认值") - return nil // 配置文件存在但格式错误,返回错误 - } - log.Warnf("读取配置文件失败: %v", err) - return fmt.Errorf("读取配置文件失败: %w", err) - } - - log.Infof("成功加载配置文件: %s", viper.ConfigFileUsed()) - return nil -} - -func initDBPath() error { - - // 1. 从 Viper 配置文件获取数据库路径(加载配置文件->环境变量->默认值) - if err := initDBPathWithViper(); err != nil { - return fmt.Errorf("Viper初始化数据库失败: %v", err) - } - - // 2. 从Viper获取数据库路径 - dbPath = viper.GetString("db.path") - dbName = viper.GetString("db.name") - - // 3. 拼接数据库路径 - dbFullPath = filepath.Join(dbPath, dbName) - log.Infof("数据库完整路径: %s", dbFullPath) - - // 3. 检查数据库文件是否存在 - dir := filepath.Dir(dbFullPath) - // 4. 检查数据库目录是否存在 - if _, err := os.Stat(dir); os.IsNotExist(err) { - log.Infof("数据库目录不存在,创建目录: %s", dir) - - if err := os.MkdirAll(dir, 0755); err != nil { - return fmt.Errorf("创建数据库目录失败: %v", err) - } - } - return nil -} - -var databaseCmd = &cobra.Command{ - Use: "database", - Short: "初始化数据库", - Long: `初始化SQLite数据库,创建所有表结构和默认数据。 + cmd := &cobra.Command{ + Use: "database", + Short: "初始化数据库", + Long: `初始化SQLite数据库,创建所有表结构和默认数据。 示例: sunhpc init database # 初始化数据库 sunhpc init database --force # 强制重新初始化`, - RunE: func(cmd *cobra.Command, args []string) error { - if err := auth.RequireRoot(); err != nil { - return err - } - - log.Debug("初始化数据库...") - - if err := initDBPath(); err != nil { - return fmt.Errorf("初始化数据库路径失败: %w", err) - } - - log.Debugf("数据库目录存在: %s", dbPath) - - // 强制模式:用户确认 - if forceDB { - log.Warn("⚠️ 警告:强制重新初始化将清空数据库中的所有数据!") - fmt.Printf("数据库路径: %s\n", dbFullPath) - fmt.Print("确认要重新初始化数据库吗?这将删除所有现有数据。(Y/yes): ") - - reader := bufio.NewReader(os.Stdin) - input, err := reader.ReadString('\n') - if err != nil { - return fmt.Errorf("读取用户输入失败: %v", err) + Annotations: map[string]string{ + "skip-db-check": "true", // 标记此命令跳过数据库检查 + }, + RunE: func(cmd *cobra.Command, args []string) error { + if err := auth.RequireRoot(); err != nil { + return err } - input = strings.TrimSpace(strings.ToLower(input)) - if input != "y" && input != "yes" { - log.Info("操作已取消") - return nil + log.Info("初始化数据库...") + if force { + log.Warn("⚠️ 警告:强制重新初始化将清空数据库中的所有数据!") } - log.Info("用户确认重新初始化数据库") - } - - // 数据库存在且不是强制模式则跳过初始化 - if _, err := os.Stat(dbFullPath); err == nil && !forceDB { - log.Infof("数据库文件已存在: %s", dbFullPath) + dbInst := db.MustGetDB() // panic if fail (ok for CLI tool) + if err := dbInst.InitSchema(force); err != nil { + return err + } + log.Info("数据库初始化完成") return nil - } + }, + } - // 初始化数据库(使用配置的路径) - database, err := db.GetInstanceWithConfig(dbPath, dbName) - if err != nil { - return fmt.Errorf("初始化数据库失败: %v", err) - } - defer database.Close() - - // 如果是强制模式,设置强制重新初始化标志 - if forceDB { - database.SetForceInit(true) - log.Info("强制重新初始化数据库表...") - - // 关闭现有连接以触发重新连接 - if err := database.CloseConnection(); err != nil { - return fmt.Errorf("关闭现有数据库连接失败: %v", err) - } - - // 重新连接并初始化 - if err := database.Connect(); err != nil { - return fmt.Errorf("强制重新初始化数据库失败: %v", err) - } - } - - log.Infof("数据库初始化成功: %s", dbFullPath) - return nil - }, -} - -func init() { - databaseCmd.Flags().BoolVarP(&forceDB, "force", "f", false, "强制重新初始化,删除现有数据库") - Cmd.AddCommand(databaseCmd) + cmd.Flags().BoolVarP(&force, "force", "f", false, "强制重新初始化") + return cmd } diff --git a/cmd/init/init.go b/cmd/init/init.go index 6506f0d..f659242 100644 --- a/cmd/init/init.go +++ b/cmd/init/init.go @@ -4,6 +4,7 @@ import ( "github.com/spf13/cobra" ) +// 仅定义 Cmd, 注册子命令,只负责组装命令树,尽量不包含业务逻辑 var Cmd = &cobra.Command{ Use: "init", Short: "初始化集群配置", @@ -11,7 +12,7 @@ var Cmd = &cobra.Command{ } func init() { - Cmd.AddCommand(configCmd) - Cmd.AddCommand(systemCmd) - Cmd.AddCommand(serviceCmd) + // 注册所有子命令(通过工厂函数创建, 例如 DatabaseCmd()) + Cmd.AddCommand(NewDatabaseCmd()) + Cmd.AddCommand(NewConfigCmd()) } diff --git a/cmd/init/service.go b/cmd/init/service.go deleted file mode 100644 index 44bbe96..0000000 --- a/cmd/init/service.go +++ /dev/null @@ -1,37 +0,0 @@ -package initcmd - -import ( - "fmt" - - "sunhpc/internal/auth" - "sunhpc/internal/config" - "sunhpc/internal/log" - "sunhpc/internal/service" - - "github.com/spf13/cobra" -) - -var serviceCmd = &cobra.Command{ - Use: "service", - Short: "根据配置文件初始化服务", - Long: `读取 /etc/sunhpc/services.yaml 并部署/配置相关服务。 -支持 HTTPD、TFTPD、DHCPD 等。`, - RunE: func(cmd *cobra.Command, args []string) error { - if err := auth.RequireRoot(); err != nil { - return err - } - - svcCfg, err := config.LoadServices() - if err != nil { - return fmt.Errorf("加载 services.yaml 失败: %v", err) - } - - log.Info("开始部署服务...") - if err := service.Deploy(svcCfg); err != nil { - return fmt.Errorf("服务部署失败: %v", err) - } - - log.Info("服务初始化完成") - return nil - }, -} diff --git a/cmd/init/system.go b/cmd/init/system.go deleted file mode 100644 index 79ed85d..0000000 --- a/cmd/init/system.go +++ /dev/null @@ -1,49 +0,0 @@ -package initcmd - -import ( - "fmt" - - "sunhpc/internal/auth" - "sunhpc/internal/config" - "sunhpc/internal/system" - - "github.com/spf13/cobra" -) - -var ( - dryRun bool // --dry-run -n: 仅模拟执行,不实际应用 - verbose bool // --verbose -v: 启用详细日志输出 -) - -var systemCmd = &cobra.Command{ - Use: "system [flags]", - Short: "根据配置文件初始化系统", - Long: `读取 /etc/sunhpc/sunhpc.yaml 中的系统配置项并应用到当前节点。 - 示例: - sunhpc init system # 应用所有配置项 - sunhpc init system --dry-run # 仅模拟执行,不实际应用 - sunhpc init system --verbose # 启用详细日志输出 - `, - RunE: func(cmd *cobra.Command, args []string) error { - // 权限检查:必须以 root 或 sudo 运行 - if err := auth.RequireRoot(); err != nil { - return err - } - - // 加载主配置 - cfg, err := config.LoadSunHPC() - if err != nil { - return fmt.Errorf("加载 sunhpc.yaml 失败: %v", err) - } - - // 统一应用所有配置 - return system.ApplyAll(cfg) - }, -} - -// init 初始化 systemCmd 的标志,添加长参数和段参数. -func init() { - // 注册长参数, 布尔参数 - systemCmd.Flags().BoolVarP(&dryRun, "dry-run", "n", false, "仅模拟执行,不实际应用") - systemCmd.Flags().BoolVarP(&verbose, "verbose", "v", false, "启用详细日志输出") -} diff --git a/cmd/node.go b/cmd/node.go deleted file mode 100644 index 8118fbe..0000000 --- a/cmd/node.go +++ /dev/null @@ -1,113 +0,0 @@ -package cmd - -import ( - "fmt" - - "sunhpc/internal/db" - "sunhpc/internal/log" - - "github.com/spf13/cobra" -) - -var nodeCmd = &cobra.Command{ - Use: "node", - Short: "节点管理", - Long: "管理集群节点,包括添加、删除、查询等操作", -} - -var nodeListCmd = &cobra.Command{ - Use: "list", - Short: "列出所有节点", - RunE: func(cmd *cobra.Command, args []string) error { - log.Info("查询节点列表...") - - // 获取数据库实例(自动使用之前配置的路径) - database, err := db.GetInstance() - if err != nil { - return fmt.Errorf("获取数据库连接失败: %v", err) - } - defer database.Close() - - // 执行查询 - _, err = database.Execute("SELECT id, name, rack, rank, cpus, memory, disk, os, kernel FROM nodes ORDER BY name") - if err != nil { - return fmt.Errorf("查询节点失败: %v", err) - } - - // 获取所有结果 - rows, err := database.FetchAll() - if err != nil { - return fmt.Errorf("获取结果失败: %v", err) - } - - if len(rows) == 0 { - log.Info("暂无节点数据") - return nil - } - - // 打印结果 - fmt.Printf("%-5s %-20s %-8s %-8s %-8s %-10s %-10s %-10s\n", - "ID", "名称", "机架", "排名", "CPU", "内存", "磁盘", "操作系统") - fmt.Println("----------------------------------------------------------------------------------") - for _, row := range rows { - fmt.Printf("%-5v %-20s %-8v %-8v %-8v %-10v %-10v %-10s\n", - row["id"], row["name"], row["rack"], row["rank"], - row["cpus"], row["memory"], row["disk"], row["os"]) - } - - return nil - }, -} - -var nodeAddCmd = &cobra.Command{ - Use: "add ", - Short: "添加节点", - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - nodeName := args[0] - log.Infof("添加节点: %s", nodeName) - - database, err := db.GetInstance() - if err != nil { - return fmt.Errorf("获取数据库连接失败: %v", err) - } - defer database.Close() - - // 插入节点 - _, err = database.Execute( - "INSERT INTO nodes (name, rack, rank, cpus, memory, disk, os, kernel) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", - nodeName, rack, rank, cpus, memory, disk, nodeOS, kernel, - ) - if err != nil { - return fmt.Errorf("添加节点失败: %v", err) - } - - log.Infof("节点 %s 添加成功", nodeName) - return nil - }, -} - -var ( - rack int - rank int - cpus int - memory int - disk int - nodeOS string - kernel string -) - -func init() { - // 添加子命令 - nodeCmd.AddCommand(nodeListCmd) - nodeCmd.AddCommand(nodeAddCmd) - - // 添加参数 - nodeAddCmd.Flags().IntVar(&rack, "rack", 0, "机架号") - nodeAddCmd.Flags().IntVar(&rank, "rank", 0, "排名") - nodeAddCmd.Flags().IntVar(&cpus, "cpus", 0, "CPU核心数") - nodeAddCmd.Flags().IntVar(&memory, "memory", 0, "内存大小(GB)") - nodeAddCmd.Flags().IntVar(&disk, "disk", 0, "磁盘大小(GB)") - nodeAddCmd.Flags().StringVar(&nodeOS, "os", "", "操作系统") - nodeAddCmd.Flags().StringVar(&kernel, "kernel", "", "内核版本") -} diff --git a/cmd/root.go b/cmd/root.go index c746cc3..f40dbe5 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,12 +1,16 @@ package cmd import ( + "os" + "strings" initcmd "sunhpc/cmd/init" "sunhpc/cmd/soft" + "sunhpc/cmd/tmpl" + "sunhpc/internal/auth" + "sunhpc/internal/config" "sunhpc/internal/log" "github.com/spf13/cobra" - "github.com/spf13/viper" ) var ( @@ -15,26 +19,85 @@ var ( noColor bool ) +func checkDB() error { + cfg, err := config.LoadConfig() + if err != nil { + log.Fatal("加载配置失败: ", err) + } + + // 统一转为小写,避免用户输入错误 + dbType := strings.ToLower(cfg.DB.Type) + + // 打印配置(调试用) + log.Debugf("数据库类型: %s", dbType) + log.Debugf("数据库名称: %s", cfg.DB.Name) + log.Debugf("数据库路径: %s", cfg.DB.Path) + log.Debugf("数据库用户: %s", cfg.DB.User) + log.Debugf("数据库主机: %s", cfg.DB.Host) + log.Debugf("数据库套接字: %s", cfg.DB.Socket) + log.Debugf("数据库详细日志: %v", cfg.DB.Verbose) + + // 支持 sqlite,mysql的常见别名 + isSQLite := dbType == "sqlite" || dbType == "sqlite3" + isMySQL := dbType == "mysql" + + // 检查数据库类型,只允许 sqlite 和 mysql + if !isSQLite && !isMySQL { + log.Fatalf("不支持的数据库类型: %s(仅支持 sqlite、sqlite3、mysql)", dbType) + } + + // 检查数据库路径是否存在 + if isSQLite { + if _, err := os.Stat(cfg.DB.Path); os.IsNotExist(err) { + log.Warnf("SQLite 数据库路径 %s 不存在", cfg.DB.Path) + log.Fatalf("必须先执行 'sunhpc init database' 初始化数据库") + } + } + return nil +} + var rootCmd = &cobra.Command{ Use: "sunhpc", Short: "SunHPC - HPC集群一体化运维工具", PersistentPreRun: func(cmd *cobra.Command, args []string) { - // 初始化日志 + // 初始化日志(verbose=false 不显示调试信息) log.Init(verbose) // 是否禁用颜色 - if noColor { - log.EnableColor(false) + log.EnableColor(!noColor) + + log.Debugf("当前命令 Annotations: %+v", cmd.Annotations) + + // 检查当前命令是否标记为跳过 DB 检查 + if cmd.Annotations["skip-db-check"] == "true" { + log.Debugf("当前命令 %s 标记为跳过 DB 检查", cmd.Name()) + return + } else { + // 检查数据库 + if err := checkDB(); err != nil { + log.Fatalf("数据库检查失败: %v", err) + } } - log.Debugf("命令: %s", cmd.Name()) + // 需要 root 权限 + if cmd.Annotations["require-root"] == "true" { + if err := auth.RequireRoot(); err != nil { + log.Fatalf("需要 root 权限: %v", err) + } + } + + log.Debugf("当前命令: %s", cmd.Name()) log.Debugf("详细模式: %v", verbose) + log.Debugf("禁用颜色: %v", noColor) }, PersistentPostRun: func(cmd *cobra.Command, args []string) { // 同步日志 log.Sync() log.Close() }, + Run: func(cmd *cobra.Command, args []string) { + cmd.Help() + }, } func Execute() error { @@ -42,32 +105,12 @@ func Execute() error { } func init() { - cobra.OnInitialize(initConfig) - rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "配置文件路径 (默认为 /etc/sunhpc/sunhpc.yaml)") rootCmd.PersistentFlags().BoolVarP(&verbose, "verbose", "v", false, "启用详细日志输出") rootCmd.PersistentFlags().BoolVar(&noColor, "no-color", false, "禁用彩色输出") - // 注册一级子命令 + // 注册一级子命令下的子命令树 rootCmd.AddCommand(initcmd.Cmd) rootCmd.AddCommand(soft.Cmd) - rootCmd.AddCommand(nodeCmd) -} - -func initConfig() { - if cfgFile != "" { - viper.SetConfigFile(cfgFile) - } else { - viper.AddConfigPath("/etc/sunhpc") - viper.SetConfigType("yaml") - viper.SetConfigName("sunhpc") - } - - viper.AutomaticEnv() - - if err := viper.ReadInConfig(); err == nil { - log.Infof("使用配置文件: %s", viper.ConfigFileUsed()) - } else { - log.Debugf("未找到配置文件: %v", err) - } + rootCmd.AddCommand(tmpl.Cmd) } diff --git a/cmd/tmpl/dump.go b/cmd/tmpl/dump.go new file mode 100644 index 0000000..7ece475 --- /dev/null +++ b/cmd/tmpl/dump.go @@ -0,0 +1,58 @@ +package tmpl + +import ( + "fmt" + + "sunhpc/internal/log" + "sunhpc/internal/templating" + + "github.com/spf13/cobra" +) + +func newDumpCmd() *cobra.Command { + var output string + + cmd := &cobra.Command{ + Use: "dump ", + Short: "导出内置模板到文件", + Long: ` + 将内置的 YAML 模板导出为可编辑的文件。 + + 示例: + sunhpc tmpl dump autofs --output ./my-autofs.yaml`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + name := args[0] + + // 检查模板是否存在 + available, _ := templating.ListEmbeddedTemplates() + found := false + for _, n := range available { + if n == name { + found = true + break + } + } + if !found { + return fmt.Errorf("内置模板 '%s' 不存在。可用模板: %v", name, available) + } + + outPath := output + if outPath == "" { + outPath = name + ".yaml" + } + + if err := templating.DumpEmbeddedTemplateToFile(name, outPath); err != nil { + return err + } + + log.Infof("内置模板 '%s' 已导出到: %s", name, outPath) + log.Infof("你可以编辑此文件,然后用以下命令使用它:") + log.Infof(" sunhpc tmpl render %s -f %s [flags]", name, outPath) + return nil + }, + } + + cmd.Flags().StringVarP(&output, "output", "o", "", "输出文件路径(默认: .yaml)") + return cmd +} diff --git a/cmd/tmpl/init.go b/cmd/tmpl/init.go new file mode 100644 index 0000000..fa9ad5d --- /dev/null +++ b/cmd/tmpl/init.go @@ -0,0 +1,16 @@ +// cmd/tmpl/init.go +package tmpl + +import "github.com/spf13/cobra" + +// Cmd 是 sunhpc tmpl 的根命令 +var Cmd = &cobra.Command{ + Use: "tmpl", + Short: "管理配置模板", + Long: "从 YAML 模板生成配置文件或脚本,支持变量替换和多阶段执行", +} + +func init() { + Cmd.AddCommand(newRenderCmd()) + Cmd.AddCommand(newDumpCmd()) +} diff --git a/cmd/tmpl/render.go b/cmd/tmpl/render.go new file mode 100644 index 0000000..5e18910 --- /dev/null +++ b/cmd/tmpl/render.go @@ -0,0 +1,96 @@ +package tmpl + +import ( + "fmt" + + "sunhpc/internal/log" + "sunhpc/internal/templating" + + "github.com/spf13/cobra" +) + +var ( + tmplFile string + hostname string + domain string + oldHostname string + ip string + clusterName string + outputRoot string +) + +func newRenderCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "render ", + Short: "渲染配置模板", + Long: "根据 YAML 模板和上下文变量生成配置文件或脚本", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + tmplName := args[0] + var template *templating.Template + var err error + + // 优先使用 -f 指定的外部模版文件 + if tmplFile != "" { + template, err = templating.LoadTemplate(tmplFile) + if err != nil { + return fmt.Errorf("加载外部模板失败: %w", err) + } + log.Infof("✅ 外部模板 '%s' 已加载\n", tmplFile) + } else { + // 否则从内置模板加载 + template, err = templating.LoadEmbeddedTemplate(tmplName) + if err != nil { + return err + } + log.Infof("✅ 内置模板 '%s' 已加载\n", tmplName) + } + + ctx := templating.Context{ + Node: templating.NodeInfo{ + Hostname: hostname, + OldHostname: oldHostname, + Domain: domain, + IP: ip, + }, + Cluster: templating.ClusterInfo{ + Name: clusterName, + }, + } + + rendered, err := template.Render(ctx) + if err != nil { + return fmt.Errorf("模板渲染失败: %w", err) + } + + // 处理 post 阶段 + if steps, ok := rendered["post"]; ok { + fmt.Println(">>> 执行 post 阶段") + if err := templating.WriteFiles(steps, outputRoot); err != nil { + return err + } + templating.PrintScripts(steps) + } + + // 处理 configure 阶段 + if steps, ok := rendered["configure"]; ok { + fmt.Println(">>> 执行 configure 阶段") + templating.PrintScripts(steps) + } + + fmt.Println("✅ 模板渲染完成") + return nil + }, + } + + cmd.Flags().StringVarP(&tmplFile, "file", "f", "", "指定模板文件路径(覆盖默认查找)") + cmd.Flags().StringVar(&hostname, "hostname", "", "节点主机名") + cmd.Flags().StringVar(&domain, "domain", "cluster.local", "DNS 域名") + cmd.Flags().StringVar(&oldHostname, "old-hostname", "", "旧主机名(用于迁移)") + cmd.Flags().StringVar(&ip, "ip", "", "节点 IP 地址") + cmd.Flags().StringVar(&clusterName, "cluster", "default", "集群名称") + cmd.Flags().StringVarP(&outputRoot, "output", "o", "/", "文件输出根目录") + + _ = cmd.MarkFlagRequired("hostname") + return cmd +} diff --git a/config.yaml b/config.yaml deleted file mode 100644 index 384d46b..0000000 --- a/config.yaml +++ /dev/null @@ -1,3 +0,0 @@ -db: - path: "/tmp/sunhpc" - name: "sunhpc.db" diff --git a/go.mod b/go.mod index b09e75f..5ee5dad 100644 --- a/go.mod +++ b/go.mod @@ -9,8 +9,10 @@ require ( ) require ( + filippo.io/edwards25519 v1.1.0 // indirect github.com/fatih/color v1.18.0 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect + github.com/go-sql-driver/mysql v1.9.3 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect diff --git a/go.sum b/go.sum index 17e108b..59a83a2 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -7,6 +9,8 @@ github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHk github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= +github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= diff --git a/internal/config/config.go b/internal/config/config.go index 3d51259..9c89bb7 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,26 +1,135 @@ package config import ( + "fmt" "os" "path/filepath" + "strings" - "gopkg.in/yaml.v3" + "github.com/spf13/viper" ) const ( - BaseDir = "/etc/sunhpc" - LogDir = "/var/log/sunhpc" - TmplDir = BaseDir + "/tmpl.d" + BaseDir string = "/etc/sunhpc" + LogDir string = "/var/log/sunhpc" + TmplDir string = BaseDir + "/tmpl.d" + appName string = "sunhpc" + defaultDBPath string = "/var/lib/sunhpc" + defaultDBName string = "sunhpc.db" ) -var ( - SunHPCFile = filepath.Join(BaseDir, "sunhpc.yaml") - NodesFile = filepath.Join(BaseDir, "nodes.yaml") - NetworkFile = filepath.Join(BaseDir, "network.yaml") - DisksFile = filepath.Join(BaseDir, "disks.yaml") - ServicesFile = filepath.Join(BaseDir, "services.yaml") - FirewallFile = filepath.Join(BaseDir, "iptables.yaml") -) +type Config struct { + DB DBConfig `yaml:"db"` + Log LogConfig `yaml:"log"` + Cluster ClusterConfig `yaml:"cluster"` +} + +type DBConfig struct { + Type string `yaml:"type"` + Path string `yaml:"path"` // SQLite: 目录路径 + Name string `yaml:"name"` // SQLite: 文件名 + User string `yaml:"user"` + Password string `yaml:"password"` + Host string `yaml:"host"` + Port int `yaml:"port"` + Socket string `yaml:"socket"` + Verbose bool `yaml:"verbose"` +} + +type LogConfig struct { + Level string `yaml:"level"` + Format string `yaml:"format"` + Output string `yaml:"output"` + FilePath string `yaml:"file_path"` +} + +type ClusterConfig struct { + Name string `yaml:"name"` + AdminEmail string `yaml:"admin_email"` + TimeZone string `yaml:"time_zone"` + NodePrefix string `yaml:"node_prefix"` +} + +// LoadConfig loads configuration with the following precedence: +// 优先级排序: +// 1. 环境变量 (prefix: SUNHPC_) +// 2. ~/.sunhpc.yaml +// 3. ./sunhpc.yaml +// 4. /etc/sunhpc/sunhpc.yaml +// 5. Default values +/* + 示例配置文件: + ```yaml + db: + type: sqlite + name: sunhpc.db + path: /var/lib/sunhpc + socket: /var/lib/sunhpc/mysql/mysqld.sock + user: root + password: "" + host: localhost + ``` + + 环境变量配置示例: + ```bash + export SUNHPC_DATABASE_TYPE=mysql + export SUNHPC_DATABASE_NAME=sunhpc + export SUNHPC_DATABASE_USER=root + export SUNHPC_DATABASE_PASSWORD=123456 + export SUNHPC_DATABASE_HOST=localhost + ``` +*/ +func LoadConfig() (*Config, error) { + v := viper.New() + + // Step 1: 设置默认值(最低优先级) + v.SetDefault("db.type", "sqlite") + v.SetDefault("db.name", "sunhpc.db") + v.SetDefault("db.path", "/var/lib/sunhpc") + v.SetDefault("db.socket", "/var/lib/sunhpc/mysql/mysqld.sock") + v.SetDefault("db.user", "") + v.SetDefault("db.password", "") + v.SetDefault("db.host", "localhost") + v.SetDefault("db.port", 3306) + v.SetDefault("db.verbose", false) + + // Step 2: 启用环境变量(高优先级) + v.SetEnvPrefix("SUNHPC") // e.g., SUNHPC_DATABASE_NAME + v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) // db.type -> SUNHPC_DB_TYPE + v.AutomaticEnv() // Auto bind env vars matching config keys + + // Step 3: 按优先级从高到低加载配置文件 + // 优先级: env > ./sunhpc.yaml > ~/.sunhpc.yaml > /etc/sunhpc/sunhpc.yaml > defaults + configFiles := []string{ + "./sunhpc.yaml", + filepath.Join(os.Getenv("HOME"), ".sunhpc.yaml"), + filepath.Join(BaseDir, "sunhpc.yaml"), + } + + var configFile string + for _, cfgFile := range configFiles { + if _, err := os.Stat(cfgFile); err == nil { + configFile = cfgFile + break // 找到第一个就停止. + } + } + + // 如果找到配置文件,就加载它. + if configFile != "" { + v.SetConfigFile(configFile) + if err := v.ReadInConfig(); err != nil { + return nil, fmt.Errorf("加载配置文件 %s 失败: %w", configFile, err) + } + } + + // 解码到结构体 + var cfg Config + if err := v.Unmarshal(&cfg); err != nil { + return nil, fmt.Errorf("解码配置到结构体失败: %w", err) + } + + return &cfg, nil +} // InitDirs 创建所有必需目录 func InitDirs() error { @@ -31,39 +140,8 @@ func InitDirs() error { } for _, d := range dirs { if err := os.MkdirAll(d, 0755); err != nil { - return err + return fmt.Errorf("创建目录 %s 失败: %w", d, err) } } return nil } - -// CreateDefaultConfigs 生成默认 YAML 配置文件 -func CreateDefaultConfigs() error { - files := map[string]interface{}{ - SunHPCFile: DefaultSunHPC(), - NodesFile: DefaultNodes(), - NetworkFile: DefaultNetwork(), - DisksFile: DefaultDisks(), - ServicesFile: DefaultServices(), - FirewallFile: DefaultFirewall(), - } - - for path, data := range files { - if err := writeYAML(path, data); err != nil { - return err - } - } - return nil -} - -func writeYAML(path string, data interface{}) error { - f, err := os.Create(path) - if err != nil { - return err - } - defer f.Close() - - enc := yaml.NewEncoder(f) - defer enc.Close() - return enc.Encode(data) -} diff --git a/internal/config/defaults.go b/internal/config/defaults.go index 354987b..221bac5 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -1,128 +1,60 @@ package config -// SunHPC 主配置 -type SunHPCConfig struct { - Hostname string `yaml:"hostname"` - MOTD string `yaml:"motd"` - Sysctl map[string]string `yaml:"sysctl"` - SELinux string `yaml:"selinux"` // enforcing, permissive, disabled - SSH SSHConfig `yaml:"ssh"` -} +import ( + "os" + "path/filepath" -type SSHConfig struct { - PermitRootLogin string `yaml:"permit_root_login"` - PasswordAuth string `yaml:"password_authentication"` -} + "gopkg.in/yaml.v3" +) -func DefaultSunHPC() *SunHPCConfig { - return &SunHPCConfig{ - Hostname: "sunhpc-master", - MOTD: "Welcome to SunHPC Cluster\n", - Sysctl: map[string]string{ - "net.ipv4.ip_forward": "1", - "vm.swappiness": "10", +// DefaultConfig 返回 SunHPC 的默认配置结构体 +func DefaultConfig() *Config { + return &Config{ + DB: DBConfig{ + Type: "sqlite", + Path: "/var/lib/sunhpc", // SQLite 数据库存放目录 + Name: "sunhpc.db", // 数据库文件名 + User: "", // SQLite 不需要 + Password: "", + Host: "", + Port: 0, + Socket: "", + Verbose: false, }, - SELinux: "enforcing", - SSH: SSHConfig{ - PermitRootLogin: "yes", - PasswordAuth: "yes", + Log: LogConfig{ + Level: "info", + Format: "text", // or "json" + Output: "stdout", + FilePath: "/var/log/sunhpc/sunhpc.log", + }, + Cluster: ClusterConfig{ + Name: "default-cluster", + AdminEmail: "admin@example.com", + TimeZone: "Asia/Shanghai", + NodePrefix: "node", }, } } -// Nodes 节点配置 -type NodesConfig struct { - Nodes []Node `yaml:"nodes"` -} - -type Node struct { - Hostname string `yaml:"hostname"` - MAC string `yaml:"mac"` - IP string `yaml:"ip"` - Role string `yaml:"role"` // master, compute, login -} - -func DefaultNodes() *NodesConfig { - return &NodesConfig{ - Nodes: []Node{ - {Hostname: "master", MAC: "00:11:22:33:44:55", IP: "192.168.1.1", Role: "master"}, - }, +// WriteDefaultConfig 将默认配置写入指定路径 +// 如果目录不存在,会自动创建(需有权限) +// 如果文件已存在且非空,会返回错误(除非调用方先删除) +func WriteDefaultConfig(path string) error { + // 确保目录存在 + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0755); err != nil { + return err } -} -// Network 网络配置 -type NetworkConfig struct { - Interface string `yaml:"interface"` - Subnet string `yaml:"subnet"` - Netmask string `yaml:"netmask"` - Gateway string `yaml:"gateway"` - DNSServers []string `yaml:"dns_servers"` -} + // 生成默认配置 + cfg := DefaultConfig() -func DefaultNetwork() *NetworkConfig { - return &NetworkConfig{ - Interface: "eth0", - Subnet: "192.168.1.0", - Netmask: "255.255.255.0", - Gateway: "192.168.1.1", - DNSServers: []string{"8.8.8.8", "114.114.114.114"}, + // 序列化为 YAML + data, err := yaml.Marshal(cfg) + if err != nil { + return err } -} -// Disks 磁盘配置 -type DisksConfig struct { - Disks []Disk `yaml:"disks"` -} - -type Disk struct { - Device string `yaml:"device"` - Mount string `yaml:"mount"` - FSType string `yaml:"fstype"` - Options string `yaml:"options"` -} - -func DefaultDisks() *DisksConfig { - return &DisksConfig{ - Disks: []Disk{ - {Device: "/dev/sda1", Mount: "/", FSType: "ext4", Options: "defaults"}, - }, - } -} - -// Services 服务配置 -type ServicesConfig struct { - HTTPD Service `yaml:"httpd"` - TFTPD Service `yaml:"tftpd"` - DHCPD Service `yaml:"dhcpd"` -} - -type Service struct { - Enabled bool `yaml:"enabled"` - Config string `yaml:"config,omitempty"` -} - -func DefaultServices() *ServicesConfig { - return &ServicesConfig{ - HTTPD: Service{Enabled: true}, - TFTPD: Service{Enabled: true}, - DHCPD: Service{Enabled: true}, - } -} - -// Firewall 防火墙配置 -type FirewallConfig struct { - DefaultPolicy string `yaml:"default_policy"` - Rules []string `yaml:"rules"` -} - -func DefaultFirewall() *FirewallConfig { - return &FirewallConfig{ - DefaultPolicy: "DROP", - Rules: []string{ - "-A INPUT -m state --state ESTABLISHED,RELATED -j ACCEPT", - "-A INPUT -p icmp -j ACCEPT", - "-A INPUT -i lo -j ACCEPT", - "-A INPUT -p tcp --dport 22 -j ACCEPT", - }, - } + // 写入文件(0644 权限) + return os.WriteFile(path, data, 0644) } diff --git a/internal/config/loaders.go b/internal/config/loaders.go deleted file mode 100644 index 2e14686..0000000 --- a/internal/config/loaders.go +++ /dev/null @@ -1,43 +0,0 @@ -package config - -import ( - "os" - - "gopkg.in/yaml.v3" -) - -func LoadSunHPC() (*SunHPCConfig, error) { - return loadYAML[SunHPCConfig](SunHPCFile) -} - -func LoadNodes() (*NodesConfig, error) { - return loadYAML[NodesConfig](NodesFile) -} - -func LoadNetwork() (*NetworkConfig, error) { - return loadYAML[NetworkConfig](NetworkFile) -} - -func LoadDisks() (*DisksConfig, error) { - return loadYAML[DisksConfig](DisksFile) -} - -func LoadServices() (*ServicesConfig, error) { - return loadYAML[ServicesConfig](ServicesFile) -} - -func LoadFirewall() (*FirewallConfig, error) { - return loadYAML[FirewallConfig](FirewallFile) -} - -func loadYAML[T any](path string) (*T, error) { - data, err := os.ReadFile(path) - if err != nil { - return nil, err - } - var cfg T - if err := yaml.Unmarshal(data, &cfg); err != nil { - return nil, err - } - return &cfg, nil -} diff --git a/internal/db/db.go b/internal/db/db.go index d8e774a..d365523 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -1,794 +1,128 @@ package db import ( - "context" "database/sql" "fmt" - "io/ioutil" "os" "path/filepath" - "strings" "sync" - "time" - - "sunhpc/internal/log" + _ "github.com/go-sql-driver/mysql" _ "github.com/mattn/go-sqlite3" + + "sunhpc/internal/config" + "sunhpc/internal/log" ) -// 全局单例 +// DB wraps the sql.DB connection pool. +type DB struct { + engine *sql.DB +} + +// Engine returns the underlying *sql.DB. +func (d *DB) Engine() *sql.DB { + return d.engine +} + +// InitSchema initializes the database schema. +// If force is true, drops existing tables before recreating them. +func (d *DB) InitSchema(force bool) error { + db := d.engine + + if force { + if err := dropTables(db); err != nil { + return fmt.Errorf("failed to drop tables: %w", err) + } + } + + // ✅ 调用 schema.go 中的函数 + for _, ddl := range CreateTableStatements() { + if _, err := db.Exec(ddl); err != nil { + return fmt.Errorf("failed to create table: %w", err) + } + } + + return nil +} + +func dropTables(db *sql.DB) error { + // ✅ 调用 schema.go 中的函数 + for _, table := range DropTableOrder() { + if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS `%s`", table)); err != nil { + return err + } + } + return nil +} + +// --- Singleton DB Instance --- var ( globalDB *DB - once sync.Once + initOnce sync.Once + initErr error ) -// DB 核心数据库类 - 对应Rocks的Database类 -type DB struct { - // 连接参数 - dbUser string - dbPasswd string - dbHost string - dbName string - dbPath string - dbSocket string - verbose bool - forceInit bool - - // 连接对象 - engine *sql.DB // 连接池 - conn *sql.Conn // 当前连接 - results *sql.Rows // 当前结果集 - - // 线程本地存储模拟 - sessions sync.Map - - mu sync.RWMutex -} - -// NewDB 创建新实例 -func NewDB() *DB { - return &DB{ - dbUser: "", - dbPasswd: "", - dbHost: "localhost", - dbName: "sunhpc", - dbPath: "/var/lib/sunhpc", - dbSocket: "/var/lib/sunhpc/mysql/mysql.sock", - verbose: false, - } -} - -// ==================== 连接参数设置/获取 ==================== - -func (db *DB) SetDBPasswd(passwd string) { - db.mu.Lock() - defer db.mu.Unlock() - db.dbPasswd = passwd -} - -func (db *DB) GetDBPasswd() string { - db.mu.RLock() - if db.dbPasswd != "" { - db.mu.RUnlock() - return db.dbPasswd - } - db.mu.RUnlock() - - db.mu.Lock() - defer db.mu.Unlock() - - // 从配置文件读取密码 - username := db.GetDBUsername() - var filename string - switch username { - case "root": - filename = "/root/.sunhpc.my.cnf" - default: - filename = fmt.Sprintf("/home/%s/.sunhpc.my.cnf", username) - } - - data, err := ioutil.ReadFile(filename) - if err != nil { - return "" - } - - lines := strings.Split(string(data), "\n") - for _, line := range lines { - line = strings.TrimSpace(line) - parts := strings.Split(line, "=") - if len(parts) == 2 && strings.TrimSpace(parts[0]) == "password" { - db.dbPasswd = strings.TrimSpace(parts[1]) - break - } - } - return db.dbPasswd -} - -func (db *DB) SetDBUsername(name string) { - db.mu.Lock() - defer db.mu.Unlock() - db.dbUser = name -} - -func (db *DB) GetDBUsername() string { - db.mu.RLock() - if db.dbUser != "" { - db.mu.RUnlock() - return db.dbUser - } - db.mu.RUnlock() - - db.mu.Lock() - defer db.mu.Unlock() - db.dbUser = os.Getenv("USER") - return db.dbUser -} - -func (db *DB) SetDBHostname(host string) { - db.mu.Lock() - defer db.mu.Unlock() - db.dbHost = host -} - -func (db *DB) GetDBHostname() string { - db.mu.RLock() - defer db.mu.RUnlock() - return db.dbHost -} - -func (db *DB) SetDBName(name string) { - db.mu.Lock() - defer db.mu.Unlock() - db.dbName = name -} - -func (db *DB) GetDBName() string { - db.mu.RLock() - defer db.mu.RUnlock() - return db.dbName -} - -func (db *DB) SetDBPath(path string) { - db.mu.Lock() - defer db.mu.Unlock() - db.dbPath = path -} - -func (db *DB) GetDBPath() string { - db.mu.RLock() - defer db.mu.RUnlock() - return db.dbPath -} - -func (db *DB) SetVerbose(verbose bool) { - db.mu.Lock() - defer db.mu.Unlock() - db.verbose = verbose -} - -func (db *DB) SetForceInit(force bool) { - db.mu.Lock() - defer db.mu.Unlock() - db.forceInit = force -} - -// ==================== 连接管理 ==================== - -// Connect 连接数据库 -func (db *DB) Connect() error { - log.Debug("连接数据库...") - db.mu.Lock() - defer db.mu.Unlock() - - log.Debug("检查 SUNHPCDEBUG 环境变量...") - if os.Getenv("SUNHPCDEBUG") != "" { - db.verbose = true - } - - // 使用SQLite - dbFullPath := filepath.Join(db.dbPath, db.dbName+".db") - log.Debugf("数据库路径: %s", dbFullPath) - - // 确保目录存在 - log.Debug("确保数据库目录存在...") - os.MkdirAll(db.dbPath, 0755) - - engine, err := sql.Open("sqlite3", dbFullPath+"?_foreign_keys=on&_journal_mode=WAL") - log.Debugf("打开数据库连接...") - if err != nil { - return fmt.Errorf("打开数据库失败: %v", err) - } - - engine.SetMaxOpenConns(10) - engine.SetMaxIdleConns(5) - engine.SetConnMaxLifetime(time.Hour) - - db.engine = engine - - conn, err := engine.Conn(context.Background()) - log.Debugf("获取数据库连接...") - if err != nil { - return fmt.Errorf("获取连接失败: %v", err) - } - db.conn = conn - - // 初始化数据库表 - if err := db.initSchema(); err != nil { - return fmt.Errorf("初始化数据库表失败: %v", err) - } - - if db.verbose { - log.Infof("数据库连接成功: %s", dbFullPath) - } - - return nil -} - -// initSchema 初始化数据库表结构 - 所有表定义在这里 -func (db *DB) initSchema() error { - log.Debug("初始化数据库表结构...") - - // 检查 nodes 表是否已存在 - var tableName string - err := db.engine.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name='nodes'").Scan(&tableName) - - if err == nil && !db.forceInit { - log.Debug("数据库表已存在,跳过初始化") - return nil - } - - if db.forceInit { - log.Warn("强制重新初始化数据库表结构...") - } else { - log.Info("首次初始化数据库表结构...") - } - - // 如果强制初始化,先删除所有表 - if db.forceInit { - log.Info("删除现有表...") - dropSQLs := []string{ - `DROP TABLE IF EXISTS resolvechain;`, - `DROP TABLE IF EXISTS hostselections;`, - `DROP TABLE IF EXISTS attributes;`, - `DROP TABLE IF EXISTS catindexes;`, - `DROP TABLE IF EXISTS categories;`, - `DROP TABLE IF EXISTS node_attrs;`, - `DROP TABLE IF EXISTS aliases;`, - `DROP TABLE IF EXISTS networks;`, - `DROP TABLE IF EXISTS subnets;`, - `DROP TABLE IF EXISTS software_installs;`, - `DROP TABLE IF EXISTS memberships;`, - `DROP TABLE IF EXISTS appliances;`, - `DROP TABLE IF EXISTS nodes;`, +func GetDB() (*DB, error) { + initOnce.Do(func() { + cfg, err := config.LoadConfig() + if err != nil { + initErr = fmt.Errorf("数据库配置文件加载失败: %w", err) + return } - for _, sql := range dropSQLs { - if _, err := db.engine.Exec(sql); err != nil { - log.Warnf("删除表失败: %v", err) + if _, err := os.Stat(cfg.DB.Path); err != nil { + // 创建数据库目录 + if err := os.MkdirAll(cfg.DB.Path, 0755); err != nil { + log.Fatalf("创建数据库目录失败: %v", err) } - } - log.Info("现有表已删除") - } - - // 开启事务 - tx, err := db.engine.Begin() - if err != nil { - return fmt.Errorf("开启事务失败: %v", err) - } - - // 使用exec执行,每条SQL单独执行 - sqls := []string{ - // 创建表 - 注意创建顺序(先创建主表,再创建有外键的表) - `CREATE TABLE IF NOT EXISTS nodes ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL UNIQUE, - rack INTEGER DEFAULT 0, - rank INTEGER DEFAULT 0, - membership_id INTEGER, - cpus INTEGER DEFAULT 0, - memory INTEGER DEFAULT 0, - disk INTEGER DEFAULT 0, - os TEXT, - kernel TEXT, - last_state_change DATETIME DEFAULT CURRENT_TIMESTAMP, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP - );`, - - `CREATE TABLE IF NOT EXISTS appliances ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL UNIQUE, - description TEXT, - node_type TEXT DEFAULT 'compute' - );`, - - `CREATE TABLE IF NOT EXISTS memberships ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL UNIQUE, - appliance_id INTEGER, - FOREIGN KEY (appliance_id) REFERENCES appliances(id) ON DELETE SET NULL - );`, - - `CREATE TABLE IF NOT EXISTS subnets ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT UNIQUE, - network TEXT, - netmask TEXT, - gateway TEXT, - dns_zone TEXT, - is_private INTEGER DEFAULT 1 - );`, - - `CREATE TABLE IF NOT EXISTS networks ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - node_id INTEGER NOT NULL, - name TEXT, - ip TEXT UNIQUE, - mac TEXT UNIQUE, - subnet_id INTEGER, - interface TEXT DEFAULT 'eth0', - FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE CASCADE, - FOREIGN KEY (subnet_id) REFERENCES subnets(id) ON DELETE SET NULL - );`, - - `CREATE TABLE IF NOT EXISTS aliases ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - node_id INTEGER NOT NULL, - name TEXT NOT NULL, - FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE CASCADE, - UNIQUE(node_id, name) - );`, - - `CREATE TABLE IF NOT EXISTS categories ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL UNIQUE - );`, - - `CREATE TABLE IF NOT EXISTS catindexes ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL, - category_id INTEGER NOT NULL, - FOREIGN KEY (category_id) REFERENCES categories(id) ON DELETE CASCADE, - UNIQUE(name, category_id) - );`, - - `CREATE TABLE IF NOT EXISTS attributes ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - attr TEXT NOT NULL, - value TEXT, - category_id INTEGER NOT NULL, - catindex_id INTEGER NOT NULL, - FOREIGN KEY (category_id) REFERENCES categories(id) ON DELETE CASCADE, - FOREIGN KEY (catindex_id) REFERENCES catindexes(id) ON DELETE CASCADE, - UNIQUE(attr, category_id, catindex_id) - );`, - - `CREATE TABLE IF NOT EXISTS node_attrs ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - node_id INTEGER NOT NULL, - attr TEXT NOT NULL, - value TEXT, - FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE CASCADE, - UNIQUE(node_id, attr) - );`, - - `CREATE TABLE IF NOT EXISTS hostselections ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - host_id INTEGER NOT NULL, - category_id INTEGER NOT NULL, - selection TEXT NOT NULL, - FOREIGN KEY (host_id) REFERENCES nodes(id) ON DELETE CASCADE, - FOREIGN KEY (category_id) REFERENCES categories(id) ON DELETE CASCADE, - UNIQUE(host_id, category_id, selection) - );`, - - `CREATE TABLE IF NOT EXISTS resolvechain ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - category_id INTEGER NOT NULL, - precedence INTEGER NOT NULL, - FOREIGN KEY (category_id) REFERENCES categories(id) ON DELETE CASCADE, - UNIQUE(category_id, precedence) - );`, - - `CREATE TABLE IF NOT EXISTS software_installs ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL, - version TEXT, - install_type TEXT, - node_id INTEGER, - status TEXT, - installed_at DATETIME DEFAULT CURRENT_TIMESTAMP, - installed_by TEXT, - FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE SET NULL - );`, - - // 创建索引 - `CREATE INDEX IF NOT EXISTS idx_nodes_name ON nodes(name);`, - `CREATE INDEX IF NOT EXISTS idx_networks_ip ON networks(ip);`, - `CREATE INDEX IF NOT EXISTS idx_networks_mac ON networks(mac);`, - `CREATE INDEX IF NOT EXISTS idx_attributes_lookup ON attributes(attr, category_id, catindex_id);`, - `CREATE INDEX IF NOT EXISTS idx_node_attrs_lookup ON node_attrs(node_id, attr);`, - `CREATE INDEX IF NOT EXISTS idx_hostselections_host ON hostselections(host_id);`, - `CREATE INDEX IF NOT EXISTS idx_resolvechain_precedence ON resolvechain(precedence);`, - } - - // 逐条执行SQL - for i, sql := range sqls { - if strings.TrimSpace(sql) == "" { - continue + log.Infof("数据库目录创建成功: %s", cfg.DB.Path) } - log.Debugf("执行SQL[%d]: %s", i, strings.TrimSpace(strings.Split(sql, "\n")[0])) + var dsn string + var driver string - _, err := tx.Exec(sql) - if err != nil { - tx.Rollback() - return fmt.Errorf("执行SQL[%d]失败: %v\nSQL: %s", i, err, sql) - } - } - - // 提交事务 - if err := tx.Commit(); err != nil { - return fmt.Errorf("提交事务失败: %v", err) - } - - log.Info("数据库表结构创建成功") - - // 插入默认数据 - return db.insertDefaultData() -} - -// insertDefaultData 插入默认数据 -func (db *DB) insertDefaultData() error { - log.Debug("插入默认数据...") - // 默认类别 - categories := []string{"global", "host", "os", "appliance", "network"} - for _, cat := range categories { - _, err := db.engine.Exec( - "INSERT OR IGNORE INTO categories (name) VALUES (?)", - cat, - ) - if err != nil { - return err - } - } - - log.Debug("插入默认类别索引...") - // 默认类别索引 - catIndexes := []struct { - catName string - idxName string - }{ - {"global", "global"}, - {"os", "linux"}, - {"network", "private"}, - } - - for _, ci := range catIndexes { - _, err := db.engine.Exec(` - INSERT OR IGNORE INTO catindexes (name, category_id) - SELECT ?, id FROM categories WHERE name = ? - `, ci.idxName, ci.catName) - if err != nil { - return err - } - } - - log.Debug("插入默认解析链优先级...") - // 默认解析链优先级 - precedence := []struct { - catName string - level int - }{ - {"global", 1}, - {"os", 2}, - {"appliance", 3}, - {"host", 4}, - {"network", 5}, - } - - for _, p := range precedence { - _, err := db.engine.Exec(` - INSERT OR IGNORE INTO resolvechain (category_id, precedence) - SELECT id, ? FROM categories WHERE name = ? - `, p.level, p.catName) - if err != nil { - return err - } - } - - log.Debug("插入默认设备类型...") - // 默认设备类型 - appliances := []struct { - name string - desc string - typ string - }{ - {"frontend", "管理节点", "master"}, - {"compute", "计算节点", "compute"}, - {"login", "登录节点", "login"}, - {"storage", "存储节点", "storage"}, - } - - for _, a := range appliances { - _, err := db.engine.Exec( - "INSERT OR IGNORE INTO appliances (name, description, node_type) VALUES (?, ?, ?)", - a.name, a.desc, a.typ, - ) - if err != nil { - return err - } - } - - log.Debug("插入默认数据完成...") - return nil -} - -// ==================== 核心查询方法 ==================== - -// Execute 执行SQL语句 - 对应Rocks的execute() -func (db *DB) Execute(query string, args ...interface{}) (int64, error) { - db.mu.RLock() - conn := db.conn - verbose := db.verbose - db.mu.RUnlock() - - if conn == nil { - return 0, fmt.Errorf("没有活动数据库连接") - } - - if verbose { - log.Debugf("执行SQL: %s %v", query, args) - } - - // 判断SQL类型 - upperQuery := strings.ToUpper(strings.TrimSpace(query)) - isSelect := strings.HasPrefix(upperQuery, "SELECT") - - if isSelect { - // SELECT 查询使用 QueryContext - rows, err := conn.QueryContext(context.Background(), query, args...) - if err != nil { - // 尝试重连一次 - db.RenewConnection() - db.mu.RLock() - conn = db.conn - db.mu.RUnlock() - rows, err = conn.QueryContext(context.Background(), query, args...) - } - - if err != nil { - return 0, err - } - - // 关闭旧结果 - db.mu.Lock() - if db.results != nil { - db.results.Close() - } - db.results = rows - db.mu.Unlock() - - return 0, nil - } else { - // INSERT/UPDATE/DELETE 使用 Exec(自动提交) - result, err := conn.ExecContext(context.Background(), query, args...) - if err != nil { - // 尝试重连一次 - db.RenewConnection() - db.mu.RLock() - conn = db.conn - db.mu.RUnlock() - result, err = conn.ExecContext(context.Background(), query, args...) - } - - if err != nil { - return 0, err - } - - // 获取影响行数 - rowsAffected, err := result.RowsAffected() - if err != nil { - return 0, err - } - - if verbose { - log.Debugf("影响行数: %d", rowsAffected) - } - - return rowsAffected, nil - } -} - -// FetchOne 获取一行 - 对应Rocks的fetchone() -// 返回map[string]interface{}格式,key为列名 -func (db *DB) FetchOne() (map[string]interface{}, error) { - db.mu.RLock() - results := db.results - db.mu.RUnlock() - - if results == nil { - return nil, nil - } - - if !results.Next() { - return nil, nil - } - - columns, err := results.Columns() - if err != nil { - return nil, err - } - - values := make([]interface{}, len(columns)) - scanArgs := make([]interface{}, len(columns)) - for i := range values { - scanArgs[i] = &values[i] - } - - err = results.Scan(scanArgs...) - if err != nil { - return nil, err - } - - row := make(map[string]interface{}) - for i, col := range columns { - val := values[i] - if b, ok := val.([]byte); ok { - row[col] = string(b) - } else { - row[col] = val - } - } - - return row, nil -} - -// FetchAll 获取所有行 - 对应Rocks的fetchall() -// 返回[]map[string]interface{}格式 -func (db *DB) FetchAll() ([]map[string]interface{}, error) { - db.mu.RLock() - results := db.results - db.mu.RUnlock() - - if results == nil { - return nil, nil - } - - columns, err := results.Columns() - if err != nil { - return nil, err - } - - var rows []map[string]interface{} - - for results.Next() { - values := make([]interface{}, len(columns)) - scanArgs := make([]interface{}, len(columns)) - for i := range values { - scanArgs[i] = &values[i] - } - - err = results.Scan(scanArgs...) - if err != nil { - return nil, err - } - - row := make(map[string]interface{}) - for i, col := range columns { - val := values[i] - if b, ok := val.([]byte); ok { - row[col] = string(b) + switch cfg.DB.Type { + case "sqlite": + driver = "sqlite3" + fullPath := filepath.Join(cfg.DB.Path, cfg.DB.Name) + dsn = fmt.Sprintf("%s?_foreign_keys=on&_journal_mode=WAL&_timeout=5000", fullPath) + case "mysql": + driver = "mysql" + if cfg.DB.Socket != "" { + dsn = fmt.Sprintf("%s:%s@unix(%s)/%s?parseTime=true&loc=Local", + cfg.DB.User, cfg.DB.Password, cfg.DB.Socket, cfg.DB.Name) } else { - row[col] = val + dsn = fmt.Sprintf("%s:%s@tcp(%s)/%s?parseTime=true&loc=Local", + cfg.DB.User, cfg.DB.Password, cfg.DB.Host, cfg.DB.Name) } + default: + initErr = fmt.Errorf("unsupported database type: %s", cfg.DB.Type) + return } - rows = append(rows, row) - } - - return rows, nil -} - -// ==================== 连接维护 ==================== - -// RenewConnection 续期连接 -func (db *DB) RenewConnection() error { - db.mu.Lock() - defer db.mu.Unlock() - - if db.conn != nil { - db.conn.Close() - } - - conn, err := db.engine.Conn(context.Background()) - if err != nil { - return err - } - db.conn = conn - return nil -} - -// Close 关闭连接 -func (db *DB) Close() error { - db.mu.Lock() - defer db.mu.Unlock() - - if db.results != nil { - db.results.Close() - db.results = nil - } - if db.conn != nil { - db.conn.Close() - db.conn = nil - } - if db.engine != nil { - return db.engine.Close() - } - return nil -} - -// CloseConnection 只关闭当前连接,不关闭连接池 -func (db *DB) CloseConnection() error { - db.mu.Lock() - defer db.mu.Unlock() - - if db.results != nil { - db.results.Close() - db.results = nil - } - if db.conn != nil { - db.conn.Close() - db.conn = nil - } - return nil -} - -// ==================== 单例模式 ==================== - -var ( - instanceConfigured bool - instanceDBPath string - instanceDBName string -) - -func GetInstance() (*DB, error) { - return GetInstanceWithConfig("", "") -} - -func GetInstanceWithConfig(dbPath, dbName string) (*DB, error) { - var err error - once.Do(func() { - globalDB = NewDB() - log.Debug("创建数据库实例...") - globalDB.SetDBUsername(globalDB.GetDBUsername()) - - if dbPath != "" { - globalDB.SetDBPath(dbPath) - log.Debugf("设置数据库路径: %s", dbPath) - } - if dbName != "" { - globalDB.SetDBName(dbName) - log.Debugf("设置数据库名称: %s", dbName) + engine, err := sql.Open(driver, dsn) + if err != nil { + initErr = fmt.Errorf("failed to open database: %w", err) + return } - instanceConfigured = (dbPath != "" || dbName != "") - if dbPath != "" { - instanceDBPath = dbPath - } - if dbName != "" { - instanceDBName = dbName + if err := engine.Ping(); err != nil { + engine.Close() + initErr = fmt.Errorf("failed to ping database: %w", err) + return } - err = globalDB.Connect() + globalDB = &DB{engine: engine} }) - return globalDB, err + + return globalDB, initErr } -func IsInstanceConfigured() bool { - return instanceConfigured -} - -func GetInstanceConfig() (dbPath, dbName string) { - return instanceDBPath, instanceDBName +// MustGetDB is a helper that panics on error (use in main/init only). +func MustGetDB() *DB { + db, err := GetDB() + if err != nil { + log.Fatalf("数据库初始化失败: %v", err) + } + return db } diff --git a/internal/db/db.txt b/internal/db/db.txt new file mode 100644 index 0000000..c1182ef --- /dev/null +++ b/internal/db/db.txt @@ -0,0 +1,637 @@ +package db + +import ( + "context" + "database/sql" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "sunhpc/internal/log" + + _ "github.com/mattn/go-sqlite3" +) + +// 全局单例 +var ( + globalDB *DB + once sync.Once +) + +// DB 核心数据库类 - 对应Rocks的Database类 +type DB struct { + // 连接参数 + dbUser string + dbPasswd string + dbHost string + dbName string + dbPath string + dbSocket string + verbose bool + forceInit bool + + // 连接对象 + engine *sql.DB // 连接池 + conn *sql.Conn // 当前连接 + results *sql.Rows // 当前结果集 + + // 线程本地存储模拟 + sessions sync.Map + + mu sync.RWMutex +} + +// NewDB 创建新实例 +func NewDB() *DB { + return &DB{ + dbUser: "", + dbPasswd: "", + dbHost: "localhost", + dbName: "sunhpc", + dbPath: "/var/lib/sunhpc", + dbSocket: "/var/lib/sunhpc/mysql/mysql.sock", + verbose: false, + } +} + +// ==================== 连接管理 ==================== +// Connect 连接数据库 +func (db *DB) Connect() error { + log.Debug("连接数据库...") + db.mu.Lock() + defer db.mu.Unlock() + + log.Debug("检查 SUNHPCDEBUG 环境变量...") + if os.Getenv("SUNHPCDEBUG") != "" { + db.verbose = true + } + + // 使用SQLite + dbFullPath := filepath.Join(db.dbPath, db.dbName+".db") + log.Debugf("数据库路径: %s", dbFullPath) + + // 确保目录存在 + log.Debug("确保数据库目录存在...") + os.MkdirAll(db.dbPath, 0755) + + engine, err := sql.Open("sqlite3", dbFullPath+"?_foreign_keys=on&_journal_mode=WAL") + log.Debugf("打开数据库连接...") + if err != nil { + return fmt.Errorf("打开数据库失败: %v", err) + } + + engine.SetMaxOpenConns(10) + engine.SetMaxIdleConns(5) + engine.SetConnMaxLifetime(time.Hour) + + db.engine = engine + + conn, err := engine.Conn(context.Background()) + log.Debugf("获取数据库连接...") + if err != nil { + return fmt.Errorf("获取连接失败: %v", err) + } + db.conn = conn + + // 初始化数据库表 + if err := db.initSchema(); err != nil { + return fmt.Errorf("初始化数据库表失败: %v", err) + } + + if db.verbose { + log.Infof("数据库连接成功: %s", dbFullPath) + } + + return nil +} + +// initSchema 初始化数据库表结构 - 所有表定义在这里 +func (db *DB) initSchema() error { + log.Debug("初始化数据库表结构...") + + // 检查 nodes 表是否已存在 + var tableName string + err := db.engine.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name='nodes'").Scan(&tableName) + + if err == nil && !db.forceInit { + log.Debug("数据库表已存在,跳过初始化") + return nil + } + + if db.forceInit { + log.Warn("强制重新初始化数据库表结构...") + } else { + log.Info("首次初始化数据库表结构...") + } + + // 如果强制初始化,先删除所有表 + if db.forceInit { + log.Info("删除现有表...") + dropSQLs := []string{ + `DROP TABLE IF EXISTS resolvechain;`, + `DROP TABLE IF EXISTS hostselections;`, + `DROP TABLE IF EXISTS attributes;`, + `DROP TABLE IF EXISTS catindexes;`, + `DROP TABLE IF EXISTS categories;`, + `DROP TABLE IF EXISTS node_attrs;`, + `DROP TABLE IF EXISTS aliases;`, + `DROP TABLE IF EXISTS networks;`, + `DROP TABLE IF EXISTS subnets;`, + `DROP TABLE IF EXISTS software_installs;`, + `DROP TABLE IF EXISTS memberships;`, + `DROP TABLE IF EXISTS appliances;`, + `DROP TABLE IF EXISTS nodes;`, + } + + for _, sql := range dropSQLs { + if _, err := db.engine.Exec(sql); err != nil { + log.Warnf("删除表失败: %v", err) + } + } + log.Info("现有表已删除") + } + + // 开启事务 + tx, err := db.engine.Begin() + if err != nil { + return fmt.Errorf("开启事务失败: %v", err) + } + + // 使用exec执行,每条SQL单独执行 + sqls := []string{ + // 创建表 - 注意创建顺序(先创建主表,再创建有外键的表) + `CREATE TABLE IF NOT EXISTS nodes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + rack INTEGER DEFAULT 0, + rank INTEGER DEFAULT 0, + membership_id INTEGER, + cpus INTEGER DEFAULT 0, + memory INTEGER DEFAULT 0, + disk INTEGER DEFAULT 0, + os TEXT, + kernel TEXT, + last_state_change DATETIME DEFAULT CURRENT_TIMESTAMP, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP + );`, + + `CREATE TABLE IF NOT EXISTS appliances ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + description TEXT, + node_type TEXT DEFAULT 'compute' + );`, + + `CREATE TABLE IF NOT EXISTS memberships ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + appliance_id INTEGER, + FOREIGN KEY (appliance_id) REFERENCES appliances(id) ON DELETE SET NULL + );`, + + `CREATE TABLE IF NOT EXISTS subnets ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT UNIQUE, + network TEXT, + netmask TEXT, + gateway TEXT, + dns_zone TEXT, + is_private INTEGER DEFAULT 1 + );`, + + `CREATE TABLE IF NOT EXISTS networks ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + node_id INTEGER NOT NULL, + name TEXT, + ip TEXT UNIQUE, + mac TEXT UNIQUE, + subnet_id INTEGER, + interface TEXT DEFAULT 'eth0', + FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE CASCADE, + FOREIGN KEY (subnet_id) REFERENCES subnets(id) ON DELETE SET NULL + );`, + + `CREATE TABLE IF NOT EXISTS aliases ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + node_id INTEGER NOT NULL, + name TEXT NOT NULL, + FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE CASCADE, + UNIQUE(node_id, name) + );`, + + `CREATE TABLE IF NOT EXISTS categories ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE + );`, + + `CREATE TABLE IF NOT EXISTS catindexes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + category_id INTEGER NOT NULL, + FOREIGN KEY (category_id) REFERENCES categories(id) ON DELETE CASCADE, + UNIQUE(name, category_id) + );`, + + `CREATE TABLE IF NOT EXISTS attributes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + attr TEXT NOT NULL, + value TEXT, + category_id INTEGER NOT NULL, + catindex_id INTEGER NOT NULL, + FOREIGN KEY (category_id) REFERENCES categories(id) ON DELETE CASCADE, + FOREIGN KEY (catindex_id) REFERENCES catindexes(id) ON DELETE CASCADE, + UNIQUE(attr, category_id, catindex_id) + );`, + + `CREATE TABLE IF NOT EXISTS node_attrs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + node_id INTEGER NOT NULL, + attr TEXT NOT NULL, + value TEXT, + FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE CASCADE, + UNIQUE(node_id, attr) + );`, + + `CREATE TABLE IF NOT EXISTS hostselections ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + host_id INTEGER NOT NULL, + category_id INTEGER NOT NULL, + selection TEXT NOT NULL, + FOREIGN KEY (host_id) REFERENCES nodes(id) ON DELETE CASCADE, + FOREIGN KEY (category_id) REFERENCES categories(id) ON DELETE CASCADE, + UNIQUE(host_id, category_id, selection) + );`, + + `CREATE TABLE IF NOT EXISTS resolvechain ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + category_id INTEGER NOT NULL, + precedence INTEGER NOT NULL, + FOREIGN KEY (category_id) REFERENCES categories(id) ON DELETE CASCADE, + UNIQUE(category_id, precedence) + );`, + + `CREATE TABLE IF NOT EXISTS software_installs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + version TEXT, + install_type TEXT, + node_id INTEGER, + status TEXT, + installed_at DATETIME DEFAULT CURRENT_TIMESTAMP, + installed_by TEXT, + FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE SET NULL + );`, + + // 创建索引 + `CREATE INDEX IF NOT EXISTS idx_nodes_name ON nodes(name);`, + `CREATE INDEX IF NOT EXISTS idx_networks_ip ON networks(ip);`, + `CREATE INDEX IF NOT EXISTS idx_networks_mac ON networks(mac);`, + `CREATE INDEX IF NOT EXISTS idx_attributes_lookup ON attributes(attr, category_id, catindex_id);`, + `CREATE INDEX IF NOT EXISTS idx_node_attrs_lookup ON node_attrs(node_id, attr);`, + `CREATE INDEX IF NOT EXISTS idx_hostselections_host ON hostselections(host_id);`, + `CREATE INDEX IF NOT EXISTS idx_resolvechain_precedence ON resolvechain(precedence);`, + } + + // 逐条执行SQL + for i, sql := range sqls { + if strings.TrimSpace(sql) == "" { + continue + } + + log.Debugf("执行SQL[%d]: %s", i, strings.TrimSpace(strings.Split(sql, "\n")[0])) + + _, err := tx.Exec(sql) + if err != nil { + tx.Rollback() + return fmt.Errorf("执行SQL[%d]失败: %v\nSQL: %s", i, err, sql) + } + } + + // 提交事务 + if err := tx.Commit(); err != nil { + return fmt.Errorf("提交事务失败: %v", err) + } + + log.Info("数据库表结构创建成功") + + // 插入默认数据 + return db.insertDefaultData() +} + +// insertDefaultData 插入默认数据 +func (db *DB) insertDefaultData() error { + log.Debug("插入默认数据...") + // 默认类别 + categories := []string{"global", "host", "os", "appliance", "network"} + for _, cat := range categories { + _, err := db.engine.Exec( + "INSERT OR IGNORE INTO categories (name) VALUES (?)", + cat, + ) + if err != nil { + return err + } + } + + log.Debug("插入默认类别索引...") + // 默认类别索引 + catIndexes := []struct { + catName string + idxName string + }{ + {"global", "global"}, + {"os", "linux"}, + {"network", "private"}, + } + + for _, ci := range catIndexes { + _, err := db.engine.Exec(` + INSERT OR IGNORE INTO catindexes (name, category_id) + SELECT ?, id FROM categories WHERE name = ? + `, ci.idxName, ci.catName) + if err != nil { + return err + } + } + + log.Debug("插入默认解析链优先级...") + // 默认解析链优先级 + precedence := []struct { + catName string + level int + }{ + {"global", 1}, + {"os", 2}, + {"appliance", 3}, + {"host", 4}, + {"network", 5}, + } + + for _, p := range precedence { + _, err := db.engine.Exec(` + INSERT OR IGNORE INTO resolvechain (category_id, precedence) + SELECT id, ? FROM categories WHERE name = ? + `, p.level, p.catName) + if err != nil { + return err + } + } + + log.Debug("插入默认设备类型...") + // 默认设备类型 + appliances := []struct { + name string + desc string + typ string + }{ + {"frontend", "管理节点", "master"}, + {"compute", "计算节点", "compute"}, + {"login", "登录节点", "login"}, + {"storage", "存储节点", "storage"}, + } + + for _, a := range appliances { + _, err := db.engine.Exec( + "INSERT OR IGNORE INTO appliances (name, description, node_type) VALUES (?, ?, ?)", + a.name, a.desc, a.typ, + ) + if err != nil { + return err + } + } + + log.Debug("插入默认数据完成...") + return nil +} + +// ==================== 核心查询方法 ==================== + +// Execute 执行SQL语句 - 对应Rocks的execute() +func (db *DB) Execute(query string, args ...interface{}) (int64, error) { + db.mu.RLock() + conn := db.conn + verbose := db.verbose + db.mu.RUnlock() + + if conn == nil { + return 0, fmt.Errorf("没有活动数据库连接") + } + + if verbose { + log.Debugf("执行SQL: %s %v", query, args) + } + + // 判断SQL类型 + upperQuery := strings.ToUpper(strings.TrimSpace(query)) + isSelect := strings.HasPrefix(upperQuery, "SELECT") + + if isSelect { + // SELECT 查询使用 QueryContext + rows, err := conn.QueryContext(context.Background(), query, args...) + if err != nil { + // 尝试重连一次 + db.RenewConnection() + db.mu.RLock() + conn = db.conn + db.mu.RUnlock() + rows, err = conn.QueryContext(context.Background(), query, args...) + } + + if err != nil { + return 0, err + } + + // 关闭旧结果 + db.mu.Lock() + if db.results != nil { + db.results.Close() + } + db.results = rows + db.mu.Unlock() + + return 0, nil + } else { + // INSERT/UPDATE/DELETE 使用 Exec(自动提交) + result, err := conn.ExecContext(context.Background(), query, args...) + if err != nil { + // 尝试重连一次 + db.RenewConnection() + db.mu.RLock() + conn = db.conn + db.mu.RUnlock() + result, err = conn.ExecContext(context.Background(), query, args...) + } + + if err != nil { + return 0, err + } + + // 获取影响行数 + rowsAffected, err := result.RowsAffected() + if err != nil { + return 0, err + } + + if verbose { + log.Debugf("影响行数: %d", rowsAffected) + } + + return rowsAffected, nil + } +} + +// FetchOne 获取一行 - 对应Rocks的fetchone() +// 返回map[string]interface{}格式,key为列名 +func (db *DB) FetchOne() (map[string]interface{}, error) { + db.mu.RLock() + results := db.results + db.mu.RUnlock() + + if results == nil { + return nil, nil + } + + if !results.Next() { + return nil, nil + } + + columns, err := results.Columns() + if err != nil { + return nil, err + } + + values := make([]interface{}, len(columns)) + scanArgs := make([]interface{}, len(columns)) + for i := range values { + scanArgs[i] = &values[i] + } + + err = results.Scan(scanArgs...) + if err != nil { + return nil, err + } + + row := make(map[string]interface{}) + for i, col := range columns { + val := values[i] + if b, ok := val.([]byte); ok { + row[col] = string(b) + } else { + row[col] = val + } + } + + return row, nil +} + +// FetchAll 获取所有行 - 对应Rocks的fetchall() +// 返回[]map[string]interface{}格式 +func (db *DB) FetchAll() ([]map[string]interface{}, error) { + db.mu.RLock() + results := db.results + db.mu.RUnlock() + + if results == nil { + return nil, nil + } + + columns, err := results.Columns() + if err != nil { + return nil, err + } + + var rows []map[string]interface{} + + for results.Next() { + values := make([]interface{}, len(columns)) + scanArgs := make([]interface{}, len(columns)) + for i := range values { + scanArgs[i] = &values[i] + } + + err = results.Scan(scanArgs...) + if err != nil { + return nil, err + } + + row := make(map[string]interface{}) + for i, col := range columns { + val := values[i] + if b, ok := val.([]byte); ok { + row[col] = string(b) + } else { + row[col] = val + } + } + + rows = append(rows, row) + } + + return rows, nil +} + +// ==================== 连接维护 ==================== + +// RenewConnection 续期连接 +func (db *DB) RenewConnection() error { + db.mu.Lock() + defer db.mu.Unlock() + + if db.conn != nil { + db.conn.Close() + } + + conn, err := db.engine.Conn(context.Background()) + if err != nil { + return err + } + db.conn = conn + return nil +} + +// Close 关闭连接 +func (db *DB) Close() error { + db.mu.Lock() + defer db.mu.Unlock() + + if db.results != nil { + db.results.Close() + db.results = nil + } + if db.conn != nil { + db.conn.Close() + db.conn = nil + } + if db.engine != nil { + return db.engine.Close() + } + return nil +} + +// CloseConnection 只关闭当前连接,不关闭连接池 +func (db *DB) CloseConnection() error { + db.mu.Lock() + defer db.mu.Unlock() + + if db.results != nil { + db.results.Close() + db.results = nil + } + if db.conn != nil { + db.conn.Close() + db.conn = nil + } + return nil +} + +// ==================== 单例模式 ==================== + +var ( + instanceConfigured bool + instanceDBPath string + instanceDBName string +) diff --git a/internal/db/helper.go b/internal/db/helper.go deleted file mode 100644 index b6144e8..0000000 --- a/internal/db/helper.go +++ /dev/null @@ -1,624 +0,0 @@ -package db - -import ( - "fmt" - "net" - "os" - "strings" - "sync" -) - -/* - // 获取数据库实例 - database, err := db.GetInstance() - if err != nil { - log.Fatal(err) - } - defer database.Close() - - // 创建Helper - helper, _ := db.NewDBHelper() - - // 执行查询 - helper.Execute("SELECT * FROM nodes WHERE rack = ?", 1) - - // 获取一行 - row, _ := helper.FetchOne() - if row != nil { - log.Infof("节点: %v", row["name"]) - } - - // 获取所有行 - rows, _ := helper.FetchAll() - log.Infof("共 %d 个节点", len(rows)) - - // 使用Helper方法 - hostname, _ := helper.GetHostname("192.168.1.1") - log.Infof("解析主机名: %s", hostname) - - // 设置属性 - helper.SetCategoryAttr("global", "global", "Kickstart_PrivateHostname", "sunhpc-master") - - // 获取属性 - val := helper.GetCategoryAttr("global", "global", "Kickstart_PrivateHostname") - log.Infof("前端主机名: %s", val) -*/ - -const ( - attrPostfix = "_old" -) - -// DBHelper DatabaseHelper类 - 继承DB,扩展业务方法 -type DBHelper struct { - *DB - appliancesList []string - frontendName string - cacheAttrs sync.Map -} - -func NewDBHelper() (*DBHelper, error) { - db, err := GetInstance() - if err != nil { - return nil, err - } - return &DBHelper{ - DB: db, - appliancesList: nil, - frontendName: "", - }, nil -} - -// ==================== 节点查询 ==================== - -// GetListHostnames 获取所有主机名列表 -func (h *DBHelper) GetListHostnames() ([]string, error) { - _, err := h.Execute("SELECT name FROM nodes ORDER BY name") - if err != nil { - return nil, err - } - - rows, err := h.FetchAll() - if err != nil { - return nil, err - } - - var names []string - for _, row := range rows { - if name, ok := row["name"]; ok { - names = append(names, name.(string)) - } - } - return names, nil -} - -// GetNodesFromNames 从名称列表获取节点 -func (h *DBHelper) GetNodesFromNames(names []string, managedOnly bool) ([]map[string]interface{}, error) { - // 如果没有提供名称,返回所有节点 - if len(names) == 0 { - query := "SELECT * FROM nodes" - if managedOnly { - query = ` - SELECT n.* FROM nodes n - JOIN node_attrs a ON n.id = a.node_id - WHERE a.attr = 'managed' AND a.value = 'true' - ` - } - - _, err := h.Execute(query) - if err != nil { - return nil, err - } - return h.FetchAll() - } - - // 构建查询条件 - conditions := []string{} - args := []interface{}{} - - for _, name := range names { - if strings.HasPrefix(name, "select ") { - conditions = append(conditions, fmt.Sprintf("name IN (%s)", name[7:])) - - } else if strings.Contains(name, "%") { - conditions = append(conditions, "name LIKE ?") - args = append(args, name) - - } else if strings.HasPrefix(name, "rack") { - rackNum := strings.TrimPrefix(name, "rack") - conditions = append(conditions, "rack = ?") - args = append(args, rackNum) - - } else if h.IsApplianceName(name) { - conditions = append(conditions, `id IN ( - SELECT node_id FROM node_attrs - WHERE attr = 'appliance' AND value = ? - )`) - args = append(args, name) - - } else { - hostname, err := h.GetHostname(name) - if err == nil { - conditions = append(conditions, "name = ?") - args = append(args, hostname) - } - } - } - - if len(conditions) == 0 { - return []map[string]interface{}{}, nil - } - - query := "SELECT * FROM nodes WHERE " + strings.Join(conditions, " OR ") - _, err := h.Execute(query, args...) - if err != nil { - return nil, err - } - - nodes, err := h.FetchAll() - if err != nil { - return nil, err - } - - // 过滤受管节点 - if managedOnly { - var managed []map[string]interface{} - for _, node := range nodes { - val := h.GetHostAttr(node["name"].(string), "managed") - if val == "true" { - managed = append(managed, node) - } - } - return managed, nil - } - - return nodes, nil -} - -// ==================== 设备类型 ==================== - -// GetAppliancesListText 获取所有设备类型名称 -func (h *DBHelper) GetAppliancesListText() []string { - if h.appliancesList != nil { - return h.appliancesList - } - - _, err := h.Execute("SELECT DISTINCT value FROM node_attrs WHERE attr = 'appliance'") - if err != nil { - return []string{} - } - - rows, _ := h.FetchAll() - var apps []string - for _, row := range rows { - if val, ok := row["value"]; ok { - apps = append(apps, val.(string)) - } - } - - h.appliancesList = apps - return apps -} - -// IsApplianceName 检查是否为设备类型名称 -func (h *DBHelper) IsApplianceName(name string) bool { - for _, app := range h.GetAppliancesListText() { - if app == name { - return true - } - } - return false -} - -// ==================== 主机名解析 ==================== - -// GetHostname 规范化主机名 - 完全参考Rocks实现 -func (h *DBHelper) GetHostname(hostname string) (string, error) { - // 如果hostname为空,使用系统主机名 - if hostname == "" { - hostname, _ = os.Hostname() - hostname = strings.Split(hostname, ".")[0] - return h.GetHostname(hostname) - } - - // 1. 直接在nodes表中查找 - _, err := h.Execute("SELECT * FROM nodes WHERE name = ?", hostname) - if err == nil { - row, _ := h.FetchOne() - if row != nil { - return hostname, nil - } - } - - // 2. 尝试IP地址反向解析 - addr := net.ParseIP(hostname) - if addr != nil { - names, err := net.LookupAddr(hostname) - if err == nil && len(names) > 0 { - return h.GetHostname(strings.Split(names[0], ".")[0]) - } - } - - // 3. 在networks表中查找IP - if addr != nil { - _, err := h.Execute(` - SELECT n.name FROM nodes n - JOIN networks net ON n.id = net.node_id - WHERE net.ip = ? - `, addr.String()) - if err == nil { - row, _ := h.FetchOne() - if row != nil { - return row["name"].(string), nil - } - } - } - - // 4. 尝试MAC地址 - mac := strings.ReplaceAll(hostname, "-", ":") - _, err = h.Execute(` - SELECT n.name FROM nodes n - JOIN networks net ON n.id = net.node_id - WHERE net.mac = ? - `, mac) - if err == nil { - row, _ := h.FetchOne() - if row != nil { - return row["name"].(string), nil - } - } - - // 5. 检查别名 - _, err = h.Execute(` - SELECT n.name FROM nodes n - JOIN aliases a ON n.id = a.node_id - WHERE a.name = ? - `, hostname) - if err == nil { - row, _ := h.FetchOne() - if row != nil { - return row["name"].(string), nil - } - } - - // 6. 尝试FQDN - if strings.Contains(hostname, ".") { - parts := strings.Split(hostname, ".") - name := parts[0] - domain := strings.Join(parts[1:], ".") - - _, err := h.Execute(` - SELECT n.name FROM nodes n - JOIN networks net ON n.id = net.node_id - JOIN subnets s ON net.subnet_id = s.id - WHERE s.dns_zone = ? AND (net.name = ? OR n.name = ?) - `, domain, name, name) - if err == nil { - row, _ := h.FetchOne() - if row != nil { - return row["name"].(string), nil - } - } - } - - // 7. 如果以上都失败,抛出异常 - return "", fmt.Errorf("无法解析主机名: %s", hostname) -} - -// CheckHostnameValidity 检查主机名有效性 -func (h *DBHelper) CheckHostnameValidity(hostname string) error { - // 不能包含点 - if strings.Contains(hostname, ".") { - return fmt.Errorf("主机名 %s 不能包含点号", hostname) - } - - // 不能是rack<数字>格式 - if strings.HasPrefix(hostname, "rack") { - num := strings.TrimPrefix(hostname, "rack") - if _, err := fmt.Sscanf(num, "%d", new(int)); err == nil { - return fmt.Errorf("主机名 %s 不能是rack<数字>格式", hostname) - } - } - - // 不能是设备类型名称 - if h.IsApplianceName(hostname) { - return fmt.Errorf("主机名 %s 不能与设备类型名称相同", hostname) - } - - // 检查是否已存在 - _, err := h.GetHostname(hostname) - if err == nil { - return fmt.Errorf("节点 %s 已存在", hostname) - } - - return nil -} - -// ==================== 前端节点 ==================== - -// GetFrontendName 获取前端节点名称 -func (h *DBHelper) GetFrontendName() string { - if h.frontendName != "" { - return h.frontendName - } - - name := h.GetCategoryAttr("global", "global", "Kickstart_PrivateHostname") - if name != "" { - h.frontendName = name - } - return h.frontendName -} - -// ==================== 属性管理 ==================== - -// GetCategoryIndex 获取类别索引 -func (h *DBHelper) GetCategoryIndex(categoryName, categoryIndex string) (map[string]interface{}, map[string]interface{}, error) { - // 查询类别和索引 - _, err := h.Execute(` - SELECT c.id as cid, c.name as cname, i.id as iid, i.name as iname - FROM categories c - JOIN catindexes i ON c.id = i.category_id - WHERE c.name = ? AND i.name = ? - `, categoryName, categoryIndex) - - if err == nil { - row, _ := h.FetchOne() - if row != nil { - category := map[string]interface{}{ - "id": row["cid"], - "name": row["cname"], - } - catindex := map[string]interface{}{ - "id": row["iid"], - "name": row["iname"], - "category_id": row["cid"], - } - return category, catindex, nil - } - } - - // 不存在则创建 - // 创建类别 - _, err = h.Execute("INSERT INTO categories (name) VALUES (?)", categoryName) - if err != nil { - return nil, nil, err - } - - var catID int64 - h.Execute("SELECT last_insert_rowid()") - row, _ := h.FetchOne() - if row != nil { - catID = row["last_insert_rowid()"].(int64) - } - - // 创建索引 - _, err = h.Execute( - "INSERT INTO catindexes (name, category_id) VALUES (?, ?)", - categoryIndex, catID, - ) - if err != nil { - return nil, nil, err - } - - h.Execute("SELECT last_insert_rowid()") - row, _ = h.FetchOne() - var idxID int64 - if row != nil { - idxID = row["last_insert_rowid()"].(int64) - } - - category := map[string]interface{}{ - "id": catID, - "name": categoryName, - } - catindex := map[string]interface{}{ - "id": idxID, - "name": categoryIndex, - "category_id": catID, - } - - return category, catindex, nil -} - -// SetCategoryAttr 设置类别属性 -func (h *DBHelper) SetCategoryAttr(categoryName, catindexName, attr, value string) error { - cat, catindex, err := h.GetCategoryIndex(categoryName, catindexName) - if err != nil { - return err - } - - // 查询现有属性 - _, err = h.Execute(` - SELECT id, value FROM attributes - WHERE attr = ? AND category_id = ? AND catindex_id = ? - `, attr, cat["id"], catindex["id"]) - - if err == nil { - row, _ := h.FetchOne() - if row != nil { - // 更新现有属性 - oldValue := row["value"] - attrID := row["id"] - - _, err = h.Execute( - "UPDATE attributes SET value = ? WHERE id = ?", - value, attrID, - ) - if err != nil { - return err - } - - // 保存旧值 - if !strings.HasSuffix(attr, attrPostfix) { - h.SetCategoryAttr(categoryName, catindexName, attr+attrPostfix, oldValue.(string)) - } - return nil - } - } - - // 创建新属性 - _, err = h.Execute(` - INSERT INTO attributes (attr, value, category_id, catindex_id) - VALUES (?, ?, ?, ?) - `, attr, value, cat["id"], catindex["id"]) - - return err -} - -// GetCategoryAttr 获取类别属性 -func (h *DBHelper) GetCategoryAttr(categoryName, catindexName, attrName string) string { - cat, catindex, err := h.GetCategoryIndex(categoryName, catindexName) - if err != nil { - return "" - } - - _, err = h.Execute(` - SELECT value FROM attributes - WHERE attr = ? AND category_id = ? AND catindex_id = ? - `, attrName, cat["id"], catindex["id"]) - - if err != nil { - return "" - } - - row, _ := h.FetchOne() - if row == nil { - return "" - } - - return row["value"].(string) -} - -// RemoveCategoryAttr 移除类别属性 -func (h *DBHelper) RemoveCategoryAttr(categoryName, catindexName, attrName string) error { - cat, catindex, err := h.GetCategoryIndex(categoryName, catindexName) - if err != nil { - return err - } - - _, err = h.Execute(` - DELETE FROM attributes - WHERE attr = ? AND category_id = ? AND catindex_id = ? - `, attrName, cat["id"], catindex["id"]) - - if err != nil { - return err - } - - // 同时删除对应的_old属性 - _, _ = h.Execute(` - DELETE FROM attributes - WHERE attr = ? AND category_id = ? AND catindex_id = ? - `, attrName+attrPostfix, cat["id"], catindex["id"]) - - return nil -} - -// ==================== 主机属性 ==================== - -// GetHostAttr 获取主机属性 -func (h *DBHelper) GetHostAttr(hostname, attr string) string { - // 先从节点直接属性查询 - _, err := h.Execute(` - SELECT value FROM node_attrs - WHERE node_id = (SELECT id FROM nodes WHERE name = ?) - AND attr = ? - `, hostname, attr) - - if err == nil { - row, _ := h.FetchOne() - if row != nil { - return row["value"].(string) - } - } - - // 使用Rocks的属性解析链查询 - query := ` - SELECT a.value FROM attributes a - JOIN resolvechain r ON a.category_id = r.category_id - JOIN hostselections h ON a.category_id = h.category_id - AND a.catindex_id = h.selection - WHERE h.host_id = (SELECT id FROM nodes WHERE name = ?) - AND a.attr = ? - ORDER BY r.precedence DESC - LIMIT 1 - ` - - _, err = h.Execute(query, hostname, attr) - if err != nil { - return "" - } - - row, _ := h.FetchOne() - if row == nil { - return "" - } - - return row["value"].(string) -} - -// GetHostAttrs 获取主机所有属性 -func (h *DBHelper) GetHostAttrs(hostname string, showSource bool) map[string]interface{} { - attrs := make(map[string]interface{}) - - // 获取节点基本信息 - _, err := h.Execute(` - SELECT n.id, n.name, n.rack, n.rank, m.name as membership, a.name as appliance - FROM nodes n - LEFT JOIN memberships m ON n.membership_id = m.id - LEFT JOIN appliances a ON m.appliance_id = a.id - WHERE n.name = ? - `, hostname) - - if err == nil { - row, _ := h.FetchOne() - if row != nil { - if showSource { - attrs["hostname"] = []interface{}{row["name"], "I"} - attrs["rack"] = []interface{}{row["rack"], "I"} - attrs["rank"] = []interface{}{row["rank"], "I"} - attrs["appliance"] = []interface{}{row["appliance"], "I"} - attrs["membership"] = []interface{}{row["membership"], "I"} - } else { - attrs["hostname"] = row["name"] - attrs["rack"] = row["rack"] - attrs["rank"] = row["rank"] - attrs["appliance"] = row["appliance"] - attrs["membership"] = row["membership"] - } - } - } - - // 获取所有属性 - query := ` - SELECT a.attr, a.value, - CASE - WHEN h.host_id IS NOT NULL THEN 'H' - ELSE UPPER(SUBSTR(c.name, 1, 1)) - END as source - FROM attributes a - JOIN categories c ON a.category_id = c.id - LEFT JOIN hostselections h ON a.category_id = h.category_id - AND a.catindex_id = h.selection - AND h.host_id = (SELECT id FROM nodes WHERE name = ?) - UNION - SELECT attr, value, 'N' as source - FROM node_attrs - WHERE node_id = (SELECT id FROM nodes WHERE name = ?) - ` - - _, err = h.Execute(query, hostname, hostname) - if err == nil { - rows, _ := h.FetchAll() - for _, row := range rows { - attr := row["attr"].(string) - value := row["value"] - if showSource { - attrs[attr] = []interface{}{value, row["source"]} - } else { - attrs[attr] = value - } - } - } - - return attrs -} diff --git a/internal/db/schema.go b/internal/db/schema.go new file mode 100644 index 0000000..6d79321 --- /dev/null +++ b/internal/db/schema.go @@ -0,0 +1,84 @@ +// Package db defines the database schema. +package db + +// CurrentSchemaVersion returns the current schema version (for migrations) +func CurrentSchemaVersion() int { + return 1 +} + +// CreateTableStatements returns a list of CREATE TABLE statements. +func CreateTableStatements() []string { + return []string{ + createNodesTable(), + createAttributesTable(), + createNetworksTable(), + createSubnetsTable(), + createSoftwareTable(), + } +} + +// DropTableOrder returns table names in reverse dependency order for safe DROP. +func DropTableOrder() []string { + return []string{"software", "attributes", "nodes", "subnets", "networks"} +} + +// --- Private DDL Functions --- + +func createNodesTable() string { + return ` +CREATE TABLE IF NOT EXISTS nodes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + hostname TEXT NOT NULL UNIQUE, + ip TEXT, + status TEXT DEFAULT 'active', + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP +);` +} + +func createAttributesTable() string { + return ` +CREATE TABLE IF NOT EXISTS attributes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + node_id INTEGER NOT NULL, + key TEXT NOT NULL, + value TEXT NOT NULL, + FOREIGN KEY(node_id) REFERENCES nodes(id) ON DELETE CASCADE, + UNIQUE(node_id, key) +);` +} + +func createNetworksTable() string { + return ` +CREATE TABLE IF NOT EXISTS networks ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + description TEXT, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP +);` +} + +func createSubnetsTable() string { + return ` +CREATE TABLE IF NOT EXISTS subnets ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + network_id INTEGER NOT NULL, + cidr TEXT NOT NULL, + gateway TEXT, + vlan INTEGER, + FOREIGN KEY(network_id) REFERENCES networks(id) ON DELETE CASCADE, + UNIQUE(network_id, cidr) +);` +} + +func createSoftwareTable() string { + return ` +CREATE TABLE IF NOT EXISTS software ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + version TEXT NOT NULL, + installed_on INTEGER, + FOREIGN KEY(installed_on) REFERENCES nodes(id) ON DELETE SET NULL, + UNIQUE(name, version, installed_on) +);` +} diff --git a/internal/model/db.go b/internal/model/db.go new file mode 100644 index 0000000..21c82ed --- /dev/null +++ b/internal/model/db.go @@ -0,0 +1,5 @@ +package model + +type DBConfig struct { + ForceDB bool +} diff --git a/internal/service/manager.go b/internal/service/manager.go deleted file mode 100644 index 9e6f8c6..0000000 --- a/internal/service/manager.go +++ /dev/null @@ -1,25 +0,0 @@ -package service - -import ( - "fmt" - "sunhpc/internal/config" - "sunhpc/internal/log" - "sunhpc/internal/template" -) - -func Deploy(cfg *config.ServicesConfig) error { - // 示例:使用模板部署 DHCPD - if cfg.DHCPD.Enabled { - log.Info("部署 DHCPD 服务...") - // 从模板渲染配置文件 - err := template.RenderAndExecute("dhcpd.conf.tmpl", map[string]interface{}{ - "Subnet": "192.168.1.0", - "Netmask": "255.255.255.0", - }) - if err != nil { - return fmt.Errorf("DHCPD 配置失败: %v", err) - } - // 实际部署逻辑(启动服务等)... - } - return nil -} diff --git a/internal/system/hostname.go b/internal/system/hostname.go deleted file mode 100644 index bb6f8ef..0000000 --- a/internal/system/hostname.go +++ /dev/null @@ -1,97 +0,0 @@ -package system - -import ( - "fmt" - "os" - "os/exec" - "strings" -) - -// SetHostname 设置系统主机名 -// 参数: hostname - 目标主机名 -// 返回: error - 如果设置失败返回错误信息 -func SetHostname(hostname string) error { - if hostname == "" { - return nil // 空值跳过,不报错 - } - - // 检查是否已有相同主机名 - current, err := os.Hostname() - if err == nil && current == hostname { - return nil // 已经设置正确,无需修改 - } - - // 使用 hostnamectl 设置主机名(适用于 systemd 系统) - if _, err := exec.LookPath("hostnamectl"); err == nil { - cmd := exec.Command("hostnamectl", "set-hostname", hostname) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - return fmt.Errorf("hostnamectl 设置主机名失败: %v", err) - } - } else { - // 传统方法:直接修改 /etc/hostname - if err := os.WriteFile("/etc/hostname", []byte(hostname+"\n"), 0644); err != nil { - return fmt.Errorf("写入 /etc/hostname 失败: %v", err) - } - - // 立即生效(需要内核支持) - cmd := exec.Command("sysctl", "kernel.hostname="+hostname) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - // 不返回错误,因为重启后会生效 - fmt.Printf("警告: 主机名将在重启后完全生效\n") - } - } - - // 更新 /etc/hosts,确保本机解析正确 - if err := updateHostsFile(hostname); err != nil { - fmt.Printf("警告: 更新 /etc/hosts 失败: %v\n", err) - } - - return nil -} - -// updateHostsFile 更新 /etc/hosts 文件中的本机映射 -func updateHostsFile(hostname string) error { - content, err := os.ReadFile("/etc/hosts") - if err != nil { - return err - } - - lines := strings.Split(string(content), "\n") - newLines := make([]string, 0, len(lines)) - hostnameSet := false - - for _, line := range lines { - // 跳过空行和注释 - if line == "" || strings.HasPrefix(line, "#") { - newLines = append(newLines, line) - continue - } - - fields := strings.Fields(line) - if len(fields) >= 2 && fields[0] == "127.0.1.1" { - // 替换 Ubuntu/Debian 风格的本地主机名 - newLines = append(newLines, "127.0.1.1\t"+hostname) - hostnameSet = true - } else if len(fields) >= 2 && fields[0] == "127.0.0.1" { - // 保留原行,但检查是否包含主机名 - if !strings.Contains(line, hostname) { - line = line + " " + hostname - } - newLines = append(newLines, line) - hostnameSet = true - } else { - newLines = append(newLines, line) - } - } - - // 如果没有找到合适的位置,添加一行 - if !hostnameSet { - newLines = append(newLines, "127.0.1.1\t"+hostname) - } - - return os.WriteFile("/etc/hosts", []byte(strings.Join(newLines, "\n")), 0644) -} diff --git a/internal/system/motd.go b/internal/system/motd.go deleted file mode 100644 index 2b4c404..0000000 --- a/internal/system/motd.go +++ /dev/null @@ -1,47 +0,0 @@ -package system - -import ( - "os" - "time" -) - -// SetMOTD 设置 /etc/motd 文件内容 -// 参数: content - MOTD 文本内容 -// 返回: error - 写入文件错误 -func SetMOTD(content string) error { - if content == "" { - // 如果内容为空,不清除现有 MOTD,避免误操作 - return nil - } - - // 添加时间和系统信息 - finalContent := "========================================\n" - finalContent += "SunHPC 集群管理系统\n" - finalContent += "时间: " + time.Now().Format("2006-01-02 15:04:05") + "\n" - finalContent += "========================================\n\n" - finalContent += content - - // 确保行尾有换行 - if content[len(content)-1] != '\n' { - finalContent += "\n" - } - - return os.WriteFile("/etc/motd", []byte(finalContent), 0644) -} - -// ClearMOTD 清空 MOTD -func ClearMOTD() error { - return os.WriteFile("/etc/motd", []byte{}, 0644) -} - -// AppendToMOTD 追加内容到 MOTD -func AppendToMOTD(additional string) error { - f, err := os.OpenFile("/etc/motd", os.O_APPEND|os.O_WRONLY, 0644) - if err != nil { - return err - } - defer f.Close() - - _, err = f.WriteString(additional + "\n") - return err -} diff --git a/internal/system/selinux.go b/internal/system/selinux.go deleted file mode 100644 index 80254bb..0000000 --- a/internal/system/selinux.go +++ /dev/null @@ -1,97 +0,0 @@ -package system - -import ( - "fmt" - "os" - "os/exec" - "strings" -) - -// ConfigureSELinux 设置 SELinux 模式 -// 参数: mode - enforcing, permissive, disabled -// 返回: error - 配置错误 -func ConfigureSELinux(mode string) error { - if mode == "" { - return nil - } - - // 验证输入 - mode = strings.ToLower(strings.TrimSpace(mode)) - validModes := map[string]bool{ - "enforcing": true, - "permissive": true, - "disabled": true, - } - - if !validModes[mode] { - return fmt.Errorf("无效的 SELinux 模式: %s (可选: enforcing, permissive, disabled)", mode) - } - - // 检查 SELinux 是否可用 - if _, err := os.Stat("/selinux/enforce"); os.IsNotExist(err) { - if _, err := os.Stat("/sys/fs/selinux/enforce"); os.IsNotExist(err) { - return fmt.Errorf("系统不支持 SELinux 或未启用") - } - } - - // 临时生效 - if mode != "disabled" { // disabled 需要重启才能完全生效 - cmd := exec.Command("setenforce", mode) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - return fmt.Errorf("setenforce %s 失败: %v", mode, err) - } - } - - // 持久化配置 - return persistSELinuxMode(mode) -} - -// persistSELinuxMode 修改 /etc/selinux/config 实现持久化 -func persistSELinuxMode(mode string) error { - const selinuxConfig = "/etc/selinux/config" - - // 读取配置文件 - content, err := os.ReadFile(selinuxConfig) - if err != nil { - // 如果文件不存在,创建默认配置 - if os.IsNotExist(err) { - defaultConfig := fmt.Sprintf(`# This file controls the state of SELinux on the system. -# SELINUX= can take one of these three values: -# enforcing - SELinux security policy is enforced. -# permissive - SELinux prints warnings instead of enforcing. -# disabled - No SELinux policy is loaded. -SELINUX=%s -# SELINUXTYPE= can take one of three two values: -# targeted - Targeted processes are protected, -# minimum - Modification of targeted policy. Only selected processes are protected. -# mls - Multi Level Security protection. -SELINUXTYPE=targeted -`, mode) - return os.WriteFile(selinuxConfig, []byte(defaultConfig), 0644) - } - return err - } - - // 替换 SELINUX= 行 - lines := strings.Split(string(content), "\n") - for i, line := range lines { - if strings.HasPrefix(line, "SELINUX=") { - lines[i] = fmt.Sprintf("SELINUX=%s", mode) - break - } - } - - return os.WriteFile(selinuxConfig, []byte(strings.Join(lines, "\n")), 0644) -} - -// GetSELinuxMode 获取当前 SELinux 模式 -func GetSELinuxMode() (string, error) { - cmd := exec.Command("getenforce") - output, err := cmd.Output() - if err != nil { - return "", err - } - return strings.ToLower(strings.TrimSpace(string(output))), nil -} diff --git a/internal/system/ssh.go b/internal/system/ssh.go deleted file mode 100644 index 4984b98..0000000 --- a/internal/system/ssh.go +++ /dev/null @@ -1,187 +0,0 @@ -package system - -import ( - "fmt" - "os" - "os/exec" - "strings" - - "sunhpc/internal/config" -) - -// ConfigureSSH 配置 SSH 服务 -// 参数: cfg - config.SSHConfig 结构体 -// 返回: error - 配置错误 -func ConfigureSSH(cfg config.SSHConfig) error { - const sshdConfig = "/etc/ssh/sshd_config" - - // 读取现有配置 - content, err := os.ReadFile(sshdConfig) - if err != nil { - return fmt.Errorf("读取 sshd_config 失败: %v", err) - } - - // 备份原始配置 - backupPath := sshdConfig + ".sunhpc.bak" - if _, err := os.Stat(backupPath); os.IsNotExist(err) { - if err := os.WriteFile(backupPath, content, 0644); err != nil { - fmt.Printf("警告: 无法创建备份文件 %s: %v\n", backupPath, err) - } - } - - // 解析和修改配置 - lines := strings.Split(string(content), "\n") - newLines := make([]string, 0, len(lines)) - configMap := make(map[string]bool) - - // 处理每一行 - for _, line := range lines { - trimmed := strings.TrimSpace(line) - if trimmed == "" || strings.HasPrefix(trimmed, "#") { - newLines = append(newLines, line) - continue - } - - parts := strings.Fields(trimmed) - if len(parts) < 2 { - newLines = append(newLines, line) - continue - } - - key := parts[0] - configMap[key] = true - - // 根据配置更新 - switch key { - case "PermitRootLogin": - if cfg.PermitRootLogin != "" { - newLines = append(newLines, fmt.Sprintf("PermitRootLogin %s", cfg.PermitRootLogin)) - } else { - newLines = append(newLines, line) - } - case "PasswordAuthentication": - if cfg.PasswordAuth != "" { - newLines = append(newLines, fmt.Sprintf("PasswordAuthentication %s", cfg.PasswordAuth)) - } else { - newLines = append(newLines, line) - } - default: - newLines = append(newLines, line) - } - } - - // 添加缺失的配置项 - if cfg.PermitRootLogin != "" && !configMap["PermitRootLogin"] { - newLines = append(newLines, fmt.Sprintf("PermitRootLogin %s", cfg.PermitRootLogin)) - } - if cfg.PasswordAuth != "" && !configMap["PasswordAuthentication"] { - newLines = append(newLines, fmt.Sprintf("PasswordAuthentication %s", cfg.PasswordAuth)) - } - - // 写入新配置 - newContent := strings.Join(newLines, "\n") - if err := os.WriteFile(sshdConfig, []byte(newContent), 0644); err != nil { - return fmt.Errorf("写入 sshd_config 失败: %v", err) - } - - // 测试配置语法 - if err := testSSHDConfig(); err != nil { - // 恢复备份 - if backup, err := os.ReadFile(backupPath); err == nil { - os.WriteFile(sshdConfig, backup, 0644) - } - return fmt.Errorf("SSH 配置语法错误: %v", err) - } - - // 重启 SSH 服务 - return restartSSHD() -} - -// testSSHDConfig 测试 sshd 配置语法 -func testSSHDConfig() error { - cmd := exec.Command("sshd", "-t") - cmd.Stderr = os.Stderr - return cmd.Run() -} - -// restartSSHD 重启 SSH 服务 -func restartSSHD() error { - // 尝试不同的服务管理器 - serviceMgrs := []struct { - name string - args []string - }{ - {"systemctl", []string{"restart", "sshd"}}, - {"systemctl", []string{"restart", "ssh"}}, - {"service", []string{"sshd", "restart"}}, - {"service", []string{"ssh", "restart"}}, - } - - for _, mgr := range serviceMgrs { - if _, err := exec.LookPath(mgr.name); err == nil { - cmd := exec.Command(mgr.name, mgr.args...) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - if err := cmd.Run(); err == nil { - return nil - } - } - } - - return fmt.Errorf("无法重启 SSH 服务,请手动重启") -} - -// AddSSHKey 添加 SSH 公钥到指定用户 -func AddSSHKey(username, pubkey string) error { - // 获取用户主目录 - homeDir, err := getUserHomeDir(username) - if err != nil { - return err - } - - sshDir := homeDir + "/.ssh" - authKeys := sshDir + "/authorized_keys" - - // 创建 .ssh 目录 - if err := os.MkdirAll(sshDir, 0700); err != nil { - return err - } - - // 追加公钥 - f, err := os.OpenFile(authKeys, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600) - if err != nil { - return err - } - defer f.Close() - - _, err = f.WriteString(pubkey + "\n") - if err != nil { - return err - } - - // 修改所有权 - return chownRecursive(sshDir, username) -} - -// getUserHomeDir 获取用户主目录 -func getUserHomeDir(username string) (string, error) { - cmd := exec.Command("getent", "passwd", username) - output, err := cmd.Output() - if err != nil { - return "", fmt.Errorf("用户 %s 不存在", username) - } - - parts := strings.Split(strings.TrimSpace(string(output)), ":") - if len(parts) >= 6 { - return parts[5], nil - } - return "", fmt.Errorf("无法获取用户 %s 的主目录", username) -} - -// chownRecursive 递归修改文件所有者 -func chownRecursive(path, username string) error { - cmd := exec.Command("chown", "-R", username+":"+username, path) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - return cmd.Run() -} diff --git a/internal/system/sysctl.go b/internal/system/sysctl.go deleted file mode 100644 index 5986e57..0000000 --- a/internal/system/sysctl.go +++ /dev/null @@ -1,87 +0,0 @@ -package system - -import ( - "bufio" - "fmt" - "os" - "os/exec" - "strings" - "time" -) - -// ConfigureSysctl 设置内核参数 -// 参数: params - 键值对映射,如 {"net.ipv4.ip_forward": "1"} -// 返回: error - 第一个失败的错误 -func ConfigureSysctl(params map[string]string) error { - if len(params) == 0 { - return nil - } - - // 首先应用临时配置 - for k, v := range params { - cmd := exec.Command("sysctl", "-w", fmt.Sprintf("%s=%s", k, v)) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - return fmt.Errorf("设置 sysctl %s=%s 失败: %v", k, v, err) - } - } - - // 持久化配置到 /etc/sysctl.conf 或 /etc/sysctl.d/ - return appendToSysctlConf(params) -} - -// appendToSysctlConf 将参数写入持久化配置文件 -func appendToSysctlConf(params map[string]string) error { - const sysctlFile = "/etc/sysctl.d/99-sunhpc.conf" - - // 读取现有配置 - existing := make(map[string]bool) - if data, err := os.ReadFile(sysctlFile); err == nil { - scanner := bufio.NewScanner(strings.NewReader(string(data))) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line == "" || strings.HasPrefix(line, "#") { - continue - } - parts := strings.SplitN(line, "=", 2) - if len(parts) == 2 { - existing[strings.TrimSpace(parts[0])] = true - } - } - } - - // 构建新内容 - var content strings.Builder - content.WriteString("# SunHPC 系统优化配置\n") - content.WriteString("# 生成时间: " + time.Now().Format(time.RFC3339) + "\n\n") - - for k, v := range params { - // 跳过已存在的配置(避免重复) - if existing[k] { - continue - } - content.WriteString(fmt.Sprintf("%s = %s\n", k, v)) - } - - // 如果有新配置才写入 - if content.Len() > 100 { - if err := os.WriteFile(sysctlFile, []byte(content.String()), 0644); err != nil { - return err - } - // 应用持久化配置 - return exec.Command("sysctl", "--system").Run() - } - - return nil -} - -// GetSysctl 获取当前内核参数值 -func GetSysctl(key string) (string, error) { - cmd := exec.Command("sysctl", "-n", key) - output, err := cmd.Output() - if err != nil { - return "", err - } - return strings.TrimSpace(string(output)), nil -} diff --git a/internal/system/system.go b/internal/system/system.go deleted file mode 100644 index 6164d3d..0000000 --- a/internal/system/system.go +++ /dev/null @@ -1,178 +0,0 @@ -package system - -import ( - "os" - "path/filepath" - "sunhpc/internal/config" - "sunhpc/internal/log" -) - -// Context 系统配置上下文,包含所有命令行参数 -type Context struct { - Force bool // 强制模式 - DryRun bool // 干运行模式 - Verbose bool // 详细输出 - Timeout int // 超时时间 - Backup string // 备份路径 - YesMode bool // 自动确认 -} - -// ApplyAll 应用所有系统配置 -func ApplyAll(cfg *config.SunHPCConfig) error { - log.Info("开始应用系统配置...") - - if err := SetHostnameWithContext(cfg.Hostname, nil); err != nil { - log.Warnf("设置主机名失败: %v", err) - } - - if err := SetMOTDWithContext(cfg.MOTD, nil); err != nil { - log.Warnf("设置 MOTD 失败: %v", err) - } - - if err := ConfigureSysctlWithContext(cfg.Sysctl, nil); err != nil { - log.Warnf("配置 sysctl 失败: %v", err) - } - - if err := ConfigureSELinuxWithContext(cfg.SELinux, nil); err != nil { - log.Warnf("配置 SELinux 失败: %v", err) - } - - if err := ConfigureSSHWithContext(cfg.SSH, nil); err != nil { - log.Warnf("配置 SSH 失败: %v", err) - } - - log.Info("系统配置应用完成") - return nil -} - -// SetHostnameWithContext 设置系统主机名,带上下文参数 -func SetHostnameWithContext(hostname string, ctx *Context) error { - if ctx != nil && ctx.DryRun { - log.Infof("[干运行] 设置主机名为: %s", hostname) - return nil - } - - if hostname == "" { - return nil - } - - // 检查是否需要强制设置 - current, _ := os.Hostname() - if current == hostname && (ctx == nil || !ctx.Force) { - log.Infof("主机名已是 '%s',跳过设置", hostname) - return nil - } - - log.Infof("设置主机名为: %s", hostname) - return SetHostname(hostname) -} - -// SetMOTDWithContext 设置 MOTD,带上下文参数 -func SetMOTDWithContext(content string, ctx *Context) error { - if ctx != nil && ctx.DryRun { - log.Info("[干运行] 设置 MOTD") - return nil - } - - if content == "" { - return nil - } - - // 备份现有文件 - if ctx != nil && ctx.Backup != "" { - backupMOTD(ctx.Backup) - } - - log.Info("更新 /etc/motd") - return SetMOTD(content) -} - -// ConfigureSysctlWithContext 配置内核参数,带上下文参数 -func ConfigureSysctlWithContext(params map[string]string, ctx *Context) error { - if ctx != nil && ctx.DryRun { - log.Info("[干运行] 配置 sysctl 参数") - return nil - } - - if len(params) == 0 { - return nil - } - - // 备份现有配置 - if ctx != nil && ctx.Backup != "" { - backupSysctl(ctx.Backup) - } - - return ConfigureSysctl(params) -} - -// ConfigureSELinuxWithContext 配置 SELinux,带上下文参数 -func ConfigureSELinuxWithContext(mode string, ctx *Context) error { - if ctx != nil && ctx.DryRun { - log.Infof("[干运行] 设置 SELinux 模式为: %s", mode) - return nil - } - - if mode == "" { - return nil - } - - // 检查当前模式 - current, _ := GetSELinuxMode() - if current == mode && (ctx == nil || !ctx.Force) { - log.Infof("SELinux 已是 '%s' 模式,跳过设置", mode) - return nil - } - - log.Infof("设置 SELinux 模式为: %s", mode) - return ConfigureSELinux(mode) -} - -// ConfigureSSHWithContext 配置 SSH,带上下文参数 -func ConfigureSSHWithContext(cfg config.SSHConfig, ctx *Context) error { - if ctx != nil && ctx.DryRun { - log.Info("[干运行] 配置 SSH 服务") - return nil - } - - // 备份配置文件 - if ctx != nil && ctx.Backup != "" { - backupSSHConfig(ctx.Backup) - } - - log.Info("配置 SSH 服务") - return ConfigureSSH(cfg) -} - -// 备份函数 -func backupMOTD(backupDir string) error { - backupPath := filepath.Join(backupDir, "motd."+filepath.Base(os.Args[0])+".bak") - if err := os.MkdirAll(backupDir, 0755); err != nil { - return err - } - return copyFile("/etc/motd", backupPath) -} - -func backupSysctl(backupDir string) error { - backupPath := filepath.Join(backupDir, "sysctl.conf.bak") - if err := os.MkdirAll(backupDir, 0755); err != nil { - return err - } - return copyFile("/etc/sysctl.conf", backupPath) -} - -func backupSSHConfig(backupDir string) error { - backupPath := filepath.Join(backupDir, "sshd_config.bak") - if err := os.MkdirAll(backupDir, 0755); err != nil { - return err - } - return copyFile("/etc/ssh/sshd_config", backupPath) -} - -func copyFile(src, dst string) error { - data, err := os.ReadFile(src) - if err != nil { - return err - } - return os.WriteFile(dst, data, 0644) -} diff --git a/internal/template/engine.go b/internal/template/engine.go deleted file mode 100644 index ca367e0..0000000 --- a/internal/template/engine.go +++ /dev/null @@ -1,61 +0,0 @@ -package template - -import ( - "bytes" - "fmt" - "os" - "os/exec" - "path/filepath" - "text/template" - - "sunhpc/internal/config" - "sunhpc/internal/log" -) - -// RenderAndExecute 从模板目录加载模板,渲染后生成临时脚本并执行 -// tmplName: 模板文件名(位于 /etc/sunhpc/tmpl.d/) -// data: 模板变量 -func RenderAndExecute(tmplName string, data interface{}) error { - tmplPath := filepath.Join(config.TmplDir, tmplName) - if _, err := os.Stat(tmplPath); err != nil { - return fmt.Errorf("模板文件不存在: %s", tmplPath) - } - - content, err := os.ReadFile(tmplPath) - if err != nil { - return err - } - - t, err := template.New(tmplName).Parse(string(content)) - if err != nil { - return err - } - - var buf bytes.Buffer - if err := t.Execute(&buf, data); err != nil { - return err - } - - // 生成临时脚本 - tmpFile, err := os.CreateTemp("/tmp", "sunhpc-*.sh") - if err != nil { - return err - } - defer os.Remove(tmpFile.Name()) - - if _, err := tmpFile.Write(buf.Bytes()); err != nil { - tmpFile.Close() - return err - } - tmpFile.Close() - - if err := os.Chmod(tmpFile.Name(), 0755); err != nil { - return err - } - - log.Infof("执行模板脚本: %s", tmpFile.Name()) - cmd := exec.Command("/bin/bash", tmpFile.Name()) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - return cmd.Run() -} diff --git a/internal/templating/embedded.go b/internal/templating/embedded.go new file mode 100644 index 0000000..fa5447c --- /dev/null +++ b/internal/templating/embedded.go @@ -0,0 +1,54 @@ +// internal/templating/embedded.go +package templating + +import ( + "fmt" + "io/fs" + "os" + "path/filepath" + + "sunhpc/tmpls" + + "gopkg.in/yaml.v3" +) + +// ListEmbeddedTemplates 返回所有内置模板名称(不含路径和扩展名) +func ListEmbeddedTemplates() ([]string, error) { + entries, err := fs.ReadDir(tmpls.FS, ".") + if err != nil { + return nil, err + } + var names []string + for _, entry := range entries { + if entry.IsDir() || filepath.Ext(entry.Name()) != ".yaml" { + continue + } + names = append(names, entry.Name()[:len(entry.Name())-5]) // 去掉 .yaml + } + return names, nil +} + +// LoadEmbeddedTemplate 从二进制加载内置模板 +func LoadEmbeddedTemplate(name string) (*Template, error) { + data, err := tmpls.FS.ReadFile(name + ".yaml") + if err != nil { + if os.IsNotExist(err) { + return nil, fmt.Errorf("内置模板 '%s' 不存在", name) + } + return nil, err + } + var tmpl Template + if err := yaml.Unmarshal(data, &tmpl); err != nil { + return nil, fmt.Errorf("解析内置模板失败: %w", err) + } + return &tmpl, nil +} + +// DumpEmbeddedTemplateToFile 将内置模板写入文件 +func DumpEmbeddedTemplateToFile(name, outputPath string) error { + data, err := tmpls.FS.ReadFile(name + ".yaml") + if err != nil { + return fmt.Errorf("找不到内置模板 '%s': %w", name, err) + } + return os.WriteFile(outputPath, data, 0644) +} diff --git a/internal/templating/engine.go b/internal/templating/engine.go new file mode 100644 index 0000000..9985c70 --- /dev/null +++ b/internal/templating/engine.go @@ -0,0 +1,104 @@ +// internal/templating/engine.go +package templating + +import ( + "bytes" + "fmt" + "os" + "path/filepath" + "text/template" + + "gopkg.in/yaml.v3" +) + +// LoadTemplate 从文件加载 YAML 模板 +func LoadTemplate(path string) (*Template, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("无法读取模板文件 %s: %w", path, err) + } + var tmpl Template + if err := yaml.Unmarshal(data, &tmpl); err != nil { + return nil, fmt.Errorf("YAML 解析失败: %w", err) + } + return &tmpl, nil +} + +// Render 渲染模板为具体操作 +func (t *Template) Render(ctx Context) (map[string][]RenderedStep, error) { + result := make(map[string][]RenderedStep) + + for stageName, steps := range t.Stages { + var renderedSteps []RenderedStep + for _, step := range steps { + // 处理 condition + if step.Condition != "" { + condTmpl, err := template.New("condition").Parse(step.Condition) + if err != nil { + return nil, fmt.Errorf("条件模板语法错误: %w", err) + } + var buf bytes.Buffer + if err := condTmpl.Execute(&buf, ctx); err != nil { + return nil, fmt.Errorf("执行条件模板失败: %w", err) + } + if buf.String() == "" { + continue // 条件不满足,跳过 + } + } + + // 渲染 content + contentTmpl, err := template.New("content").Parse(step.Content) + if err != nil { + return nil, fmt.Errorf("内容模板语法错误: %w", err) + } + var buf bytes.Buffer + if err := contentTmpl.Execute(&buf, ctx); err != nil { + return nil, fmt.Errorf("执行内容模板失败: %w", err) + } + + renderedSteps = append(renderedSteps, RenderedStep{ + Type: step.Type, + Path: step.Path, + Content: buf.String(), + }) + } + result[stageName] = renderedSteps + } + return result, nil +} + +// RenderedStep 是渲染后的步骤 +type RenderedStep struct { + Type string + Path string + Content string +} + +// WriteFiles 将 file 类型步骤写入磁盘 +func WriteFiles(steps []RenderedStep, rootDir string) error { + for _, s := range steps { + if s.Type != "file" { + continue + } + fullPath := s.Path + if !filepath.IsAbs(s.Path) { + fullPath = filepath.Join(rootDir, s.Path) + } + if err := os.MkdirAll(filepath.Dir(fullPath), 0755); err != nil { + return err + } + if err := os.WriteFile(fullPath, []byte(s.Content), 0644); err != nil { + return fmt.Errorf("写入文件 %s 失败: %w", fullPath, err) + } + } + return nil +} + +// PrintScripts 打印 script 内容(安全起见,先不自动执行) +func PrintScripts(steps []RenderedStep) { + for _, s := range steps { + if s.Type == "script" { + fmt.Printf("# --- 脚本开始 ---\n%s\n# --- 脚本结束 ---\n", s.Content) + } + } +} diff --git a/internal/templating/types.go b/internal/templating/types.go new file mode 100644 index 0000000..13fc56e --- /dev/null +++ b/internal/templating/types.go @@ -0,0 +1,38 @@ +package templating + +// Template 是 YAML 模板的顶层结构 +type Template struct { + Description string `yaml:"description,omitempty"` + Copyright string `yaml:"copyright,omitempty"` + Stages map[string][]Step `yaml:"stages"` +} + +// Step 表示一个操作步骤 +type Step struct { + Type string `yaml:"type"` // "file" 或 "script" + Path string `yaml:"path,omitempty"` // 文件路径(仅 type=file) + Content string `yaml:"content"` // 多行内容 + Condition string `yaml:"condition,omitempty"` // 条件表达式(Go template) +} + +// Context 是渲染模板时的上下文数据 +type Context struct { + Node NodeInfo `json:"node"` + Cluster ClusterInfo `json:"cluster"` +} + +// NodeInfo 节点信息 +type NodeInfo struct { + Hostname string `json:"hostname"` + OldHostname string `json:"old_hostname,omitempty"` + Domain string `json:"domain"` + IP string `json:"ip"` +} + +// ClusterInfo 集群信息 +type ClusterInfo struct { + Name string `json:"name"` + Domain string `json:"domain"` + AdminEmail string `json:"admin_email"` + TimeZone string `json:"time_zone"` +} diff --git a/main.go b/main.go index 38d95a2..62c1aeb 100644 --- a/main.go +++ b/main.go @@ -10,37 +10,3 @@ func main() { os.Exit(1) } } - -/* - // 初始化日志(verbose=true 显示调试信息) - log.Init(true) - - // 基础用法 - log.Info("服务启动成功") - log.Infof("用户 %s 已添加", "testuser") - - log.Warn("磁盘使用率超过 80%") - log.Warnf("节点 %s 网络延迟过高", "node01") - - log.Error("配置文件解析失败") - log.Errorf("无法连接到数据库: %v", err) - - log.Debug("正在执行命令: ssh root@192.168.1.1") - log.Debugf("加载了 %d 个节点配置", len(nodes)) - - // 致命错误 - if err != nil { - log.Fatal("初始化失败: ", err) - } - - // 临时禁用颜色(例如输出重定向时) - if !isTerminal { - log.EnableColor(false) - } - - // 设置日志级别 - log.SetLevel(log.WarnLevel) // 只显示警告及以上级别 - - // 启用调用者信息 - log.EnableCaller(true) -*/ diff --git a/sunhpc b/sunhpc index bbb83fd..1dbbb28 100755 Binary files a/sunhpc and b/sunhpc differ diff --git a/sunhpc.yaml b/sunhpc.yaml new file mode 100644 index 0000000..755c3ca --- /dev/null +++ b/sunhpc.yaml @@ -0,0 +1,4 @@ +db: + type: sqlite + name: sunhpc.db + path: /tmp/sunhpc diff --git a/test_db.sh b/test_db.sh deleted file mode 100644 index 651bf5c..0000000 --- a/test_db.sh +++ /dev/null @@ -1,46 +0,0 @@ -#!/bin/bash - -echo "=========================================" -echo "SunHPC 数据库测试脚本" -echo "=========================================" -echo "" - -# 1. 清理旧数据库 -echo "[1/4] 清理旧数据库..." -rm -f /var/lib/sunhpc/sunhpc.db* -rm -rf /var/lib/sunhpc/ -echo "✓ 清理完成" -echo "" - -# 2. 初始化数据库 -echo "[2/4] 初始化数据库..." -./sunhpc init database -v --force -if [ $? -ne 0 ]; then - echo "✗ 初始化失败" - exit 1 -fi -echo "✓ 初始化完成" -echo "" - -# 3. 添加节点 -echo "[3/4] 添加节点..." -./sunhpc node add node1 --cpus 32 --memory 128 --disk 1000 --os "CentOS 7.9" --kernel "3.10.0-1160.el7.x86_64" -if [ $? -ne 0 ]; then - echo "✗ 添加节点失败" - exit 1 -fi -echo "✓ 节点添加完成" -echo "" - -# 4. 查询节点 -echo "[4/4] 查询节点列表..." -./sunhpc node list -if [ $? -ne 0 ]; then - echo "✗ 查询失败" - exit 1 -fi -echo "" - -echo "=========================================" -echo "测试完成!" -echo "=========================================" diff --git a/tmpls/autofs.yaml b/tmpls/autofs.yaml new file mode 100644 index 0000000..4ab87e0 --- /dev/null +++ b/tmpls/autofs.yaml @@ -0,0 +1,29 @@ +description: AutoFS server for SunHPC clusters +copyright: | + Copyright (c) 2026 SunHPC Project. + Licensed under Apache 2.0. + +stages: + post: + - type: file + path: /etc/auto.master + content: | + /share /etc/auto.share --timeout=1200 + /home /etc/auto.home --timeout=1200 + + - type: file + path: /etc/auto.share + content: | + apps {{ .Node.Hostname }}.{{ .Cluster.Domain }}:/export/& + + - type: script + content: | + mkdir -p /export/apps + echo "AutoFS 配置已生成" + + configure: + - type: script + condition: "{{ if .Node.OldHostname }}true{{ end }}" + content: | + sed -i 's/{{ .Node.OldHostname }}/{{ .Node.Hostname }}/g' /etc/auto.share + systemctl restart autofs \ No newline at end of file diff --git a/tmpls/tmpls.go b/tmpls/tmpls.go new file mode 100644 index 0000000..44f769c --- /dev/null +++ b/tmpls/tmpls.go @@ -0,0 +1,9 @@ +package tmpls + +import ( + "embed" + _ "embed" +) + +//go:embed *.yaml +var FS embed.FS