This commit is contained in:
2026-02-14 05:36:00 +08:00
commit d7cd899983
37 changed files with 4169 additions and 0 deletions

14
internal/auth/auth.go Normal file
View File

@@ -0,0 +1,14 @@
package auth
import (
"fmt"
"os"
)
// RequireRoot 检查是否以 root 身份运行
func RequireRoot() error {
if os.Geteuid() != 0 {
return fmt.Errorf("此操作需要 root 权限,请使用 sudo 或切换到 root 用户")
}
return nil
}

69
internal/config/config.go Normal file
View File

@@ -0,0 +1,69 @@
package config
import (
"os"
"path/filepath"
"gopkg.in/yaml.v3"
)
const (
BaseDir = "/etc/sunhpc"
LogDir = "/var/log/sunhpc"
TmplDir = BaseDir + "/tmpl.d"
)
var (
SunHPCFile = filepath.Join(BaseDir, "sunhpc.yaml")
NodesFile = filepath.Join(BaseDir, "nodes.yaml")
NetworkFile = filepath.Join(BaseDir, "network.yaml")
DisksFile = filepath.Join(BaseDir, "disks.yaml")
ServicesFile = filepath.Join(BaseDir, "services.yaml")
FirewallFile = filepath.Join(BaseDir, "iptables.yaml")
)
// InitDirs 创建所有必需目录
func InitDirs() error {
dirs := []string{
BaseDir,
TmplDir,
LogDir,
}
for _, d := range dirs {
if err := os.MkdirAll(d, 0755); err != nil {
return err
}
}
return nil
}
// CreateDefaultConfigs 生成默认 YAML 配置文件
func CreateDefaultConfigs() error {
files := map[string]interface{}{
SunHPCFile: DefaultSunHPC(),
NodesFile: DefaultNodes(),
NetworkFile: DefaultNetwork(),
DisksFile: DefaultDisks(),
ServicesFile: DefaultServices(),
FirewallFile: DefaultFirewall(),
}
for path, data := range files {
if err := writeYAML(path, data); err != nil {
return err
}
}
return nil
}
func writeYAML(path string, data interface{}) error {
f, err := os.Create(path)
if err != nil {
return err
}
defer f.Close()
enc := yaml.NewEncoder(f)
defer enc.Close()
return enc.Encode(data)
}

128
internal/config/defaults.go Normal file
View File

@@ -0,0 +1,128 @@
package config
// SunHPC 主配置
type SunHPCConfig struct {
Hostname string `yaml:"hostname"`
MOTD string `yaml:"motd"`
Sysctl map[string]string `yaml:"sysctl"`
SELinux string `yaml:"selinux"` // enforcing, permissive, disabled
SSH SSHConfig `yaml:"ssh"`
}
type SSHConfig struct {
PermitRootLogin string `yaml:"permit_root_login"`
PasswordAuth string `yaml:"password_authentication"`
}
func DefaultSunHPC() *SunHPCConfig {
return &SunHPCConfig{
Hostname: "sunhpc-master",
MOTD: "Welcome to SunHPC Cluster\n",
Sysctl: map[string]string{
"net.ipv4.ip_forward": "1",
"vm.swappiness": "10",
},
SELinux: "enforcing",
SSH: SSHConfig{
PermitRootLogin: "yes",
PasswordAuth: "yes",
},
}
}
// Nodes 节点配置
type NodesConfig struct {
Nodes []Node `yaml:"nodes"`
}
type Node struct {
Hostname string `yaml:"hostname"`
MAC string `yaml:"mac"`
IP string `yaml:"ip"`
Role string `yaml:"role"` // master, compute, login
}
func DefaultNodes() *NodesConfig {
return &NodesConfig{
Nodes: []Node{
{Hostname: "master", MAC: "00:11:22:33:44:55", IP: "192.168.1.1", Role: "master"},
},
}
}
// Network 网络配置
type NetworkConfig struct {
Interface string `yaml:"interface"`
Subnet string `yaml:"subnet"`
Netmask string `yaml:"netmask"`
Gateway string `yaml:"gateway"`
DNSServers []string `yaml:"dns_servers"`
}
func DefaultNetwork() *NetworkConfig {
return &NetworkConfig{
Interface: "eth0",
Subnet: "192.168.1.0",
Netmask: "255.255.255.0",
Gateway: "192.168.1.1",
DNSServers: []string{"8.8.8.8", "114.114.114.114"},
}
}
// Disks 磁盘配置
type DisksConfig struct {
Disks []Disk `yaml:"disks"`
}
type Disk struct {
Device string `yaml:"device"`
Mount string `yaml:"mount"`
FSType string `yaml:"fstype"`
Options string `yaml:"options"`
}
func DefaultDisks() *DisksConfig {
return &DisksConfig{
Disks: []Disk{
{Device: "/dev/sda1", Mount: "/", FSType: "ext4", Options: "defaults"},
},
}
}
// Services 服务配置
type ServicesConfig struct {
HTTPD Service `yaml:"httpd"`
TFTPD Service `yaml:"tftpd"`
DHCPD Service `yaml:"dhcpd"`
}
type Service struct {
Enabled bool `yaml:"enabled"`
Config string `yaml:"config,omitempty"`
}
func DefaultServices() *ServicesConfig {
return &ServicesConfig{
HTTPD: Service{Enabled: true},
TFTPD: Service{Enabled: true},
DHCPD: Service{Enabled: true},
}
}
// Firewall 防火墙配置
type FirewallConfig struct {
DefaultPolicy string `yaml:"default_policy"`
Rules []string `yaml:"rules"`
}
func DefaultFirewall() *FirewallConfig {
return &FirewallConfig{
DefaultPolicy: "DROP",
Rules: []string{
"-A INPUT -m state --state ESTABLISHED,RELATED -j ACCEPT",
"-A INPUT -p icmp -j ACCEPT",
"-A INPUT -i lo -j ACCEPT",
"-A INPUT -p tcp --dport 22 -j ACCEPT",
},
}
}

View File

@@ -0,0 +1,43 @@
package config
import (
"os"
"gopkg.in/yaml.v3"
)
func LoadSunHPC() (*SunHPCConfig, error) {
return loadYAML[SunHPCConfig](SunHPCFile)
}
func LoadNodes() (*NodesConfig, error) {
return loadYAML[NodesConfig](NodesFile)
}
func LoadNetwork() (*NetworkConfig, error) {
return loadYAML[NetworkConfig](NetworkFile)
}
func LoadDisks() (*DisksConfig, error) {
return loadYAML[DisksConfig](DisksFile)
}
func LoadServices() (*ServicesConfig, error) {
return loadYAML[ServicesConfig](ServicesFile)
}
func LoadFirewall() (*FirewallConfig, error) {
return loadYAML[FirewallConfig](FirewallFile)
}
func loadYAML[T any](path string) (*T, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var cfg T
if err := yaml.Unmarshal(data, &cfg); err != nil {
return nil, err
}
return &cfg, nil
}

794
internal/db/db.go Normal file
View File

@@ -0,0 +1,794 @@
package db
import (
"context"
"database/sql"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"strings"
"sync"
"time"
"sunhpc/internal/log"
_ "github.com/mattn/go-sqlite3"
)
// 全局单例
var (
globalDB *DB
once sync.Once
)
// DB 核心数据库类 - 对应Rocks的Database类
type DB struct {
// 连接参数
dbUser string
dbPasswd string
dbHost string
dbName string
dbPath string
dbSocket string
verbose bool
forceInit bool
// 连接对象
engine *sql.DB // 连接池
conn *sql.Conn // 当前连接
results *sql.Rows // 当前结果集
// 线程本地存储模拟
sessions sync.Map
mu sync.RWMutex
}
// NewDB 创建新实例
func NewDB() *DB {
return &DB{
dbUser: "",
dbPasswd: "",
dbHost: "localhost",
dbName: "sunhpc",
dbPath: "/var/lib/sunhpc",
dbSocket: "/var/lib/sunhpc/mysql/mysql.sock",
verbose: false,
}
}
// ==================== 连接参数设置/获取 ====================
func (db *DB) SetDBPasswd(passwd string) {
db.mu.Lock()
defer db.mu.Unlock()
db.dbPasswd = passwd
}
func (db *DB) GetDBPasswd() string {
db.mu.RLock()
if db.dbPasswd != "" {
db.mu.RUnlock()
return db.dbPasswd
}
db.mu.RUnlock()
db.mu.Lock()
defer db.mu.Unlock()
// 从配置文件读取密码
username := db.GetDBUsername()
var filename string
switch username {
case "root":
filename = "/root/.sunhpc.my.cnf"
default:
filename = fmt.Sprintf("/home/%s/.sunhpc.my.cnf", username)
}
data, err := ioutil.ReadFile(filename)
if err != nil {
return ""
}
lines := strings.Split(string(data), "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
parts := strings.Split(line, "=")
if len(parts) == 2 && strings.TrimSpace(parts[0]) == "password" {
db.dbPasswd = strings.TrimSpace(parts[1])
break
}
}
return db.dbPasswd
}
func (db *DB) SetDBUsername(name string) {
db.mu.Lock()
defer db.mu.Unlock()
db.dbUser = name
}
func (db *DB) GetDBUsername() string {
db.mu.RLock()
if db.dbUser != "" {
db.mu.RUnlock()
return db.dbUser
}
db.mu.RUnlock()
db.mu.Lock()
defer db.mu.Unlock()
db.dbUser = os.Getenv("USER")
return db.dbUser
}
func (db *DB) SetDBHostname(host string) {
db.mu.Lock()
defer db.mu.Unlock()
db.dbHost = host
}
func (db *DB) GetDBHostname() string {
db.mu.RLock()
defer db.mu.RUnlock()
return db.dbHost
}
func (db *DB) SetDBName(name string) {
db.mu.Lock()
defer db.mu.Unlock()
db.dbName = name
}
func (db *DB) GetDBName() string {
db.mu.RLock()
defer db.mu.RUnlock()
return db.dbName
}
func (db *DB) SetDBPath(path string) {
db.mu.Lock()
defer db.mu.Unlock()
db.dbPath = path
}
func (db *DB) GetDBPath() string {
db.mu.RLock()
defer db.mu.RUnlock()
return db.dbPath
}
func (db *DB) SetVerbose(verbose bool) {
db.mu.Lock()
defer db.mu.Unlock()
db.verbose = verbose
}
func (db *DB) SetForceInit(force bool) {
db.mu.Lock()
defer db.mu.Unlock()
db.forceInit = force
}
// ==================== 连接管理 ====================
// Connect 连接数据库
func (db *DB) Connect() error {
log.Debug("连接数据库...")
db.mu.Lock()
defer db.mu.Unlock()
log.Debug("检查 SUNHPCDEBUG 环境变量...")
if os.Getenv("SUNHPCDEBUG") != "" {
db.verbose = true
}
// 使用SQLite
dbFullPath := filepath.Join(db.dbPath, db.dbName+".db")
log.Debugf("数据库路径: %s", dbFullPath)
// 确保目录存在
log.Debug("确保数据库目录存在...")
os.MkdirAll(db.dbPath, 0755)
engine, err := sql.Open("sqlite3", dbFullPath+"?_foreign_keys=on&_journal_mode=WAL")
log.Debugf("打开数据库连接...")
if err != nil {
return fmt.Errorf("打开数据库失败: %v", err)
}
engine.SetMaxOpenConns(10)
engine.SetMaxIdleConns(5)
engine.SetConnMaxLifetime(time.Hour)
db.engine = engine
conn, err := engine.Conn(context.Background())
log.Debugf("获取数据库连接...")
if err != nil {
return fmt.Errorf("获取连接失败: %v", err)
}
db.conn = conn
// 初始化数据库表
if err := db.initSchema(); err != nil {
return fmt.Errorf("初始化数据库表失败: %v", err)
}
if db.verbose {
log.Infof("数据库连接成功: %s", dbFullPath)
}
return nil
}
// initSchema 初始化数据库表结构 - 所有表定义在这里
func (db *DB) initSchema() error {
log.Debug("初始化数据库表结构...")
// 检查 nodes 表是否已存在
var tableName string
err := db.engine.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name='nodes'").Scan(&tableName)
if err == nil && !db.forceInit {
log.Debug("数据库表已存在,跳过初始化")
return nil
}
if db.forceInit {
log.Warn("强制重新初始化数据库表结构...")
} else {
log.Info("首次初始化数据库表结构...")
}
// 如果强制初始化,先删除所有表
if db.forceInit {
log.Info("删除现有表...")
dropSQLs := []string{
`DROP TABLE IF EXISTS resolvechain;`,
`DROP TABLE IF EXISTS hostselections;`,
`DROP TABLE IF EXISTS attributes;`,
`DROP TABLE IF EXISTS catindexes;`,
`DROP TABLE IF EXISTS categories;`,
`DROP TABLE IF EXISTS node_attrs;`,
`DROP TABLE IF EXISTS aliases;`,
`DROP TABLE IF EXISTS networks;`,
`DROP TABLE IF EXISTS subnets;`,
`DROP TABLE IF EXISTS software_installs;`,
`DROP TABLE IF EXISTS memberships;`,
`DROP TABLE IF EXISTS appliances;`,
`DROP TABLE IF EXISTS nodes;`,
}
for _, sql := range dropSQLs {
if _, err := db.engine.Exec(sql); err != nil {
log.Warnf("删除表失败: %v", err)
}
}
log.Info("现有表已删除")
}
// 开启事务
tx, err := db.engine.Begin()
if err != nil {
return fmt.Errorf("开启事务失败: %v", err)
}
// 使用exec执行每条SQL单独执行
sqls := []string{
// 创建表 - 注意创建顺序(先创建主表,再创建有外键的表)
`CREATE TABLE IF NOT EXISTS nodes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE,
rack INTEGER DEFAULT 0,
rank INTEGER DEFAULT 0,
membership_id INTEGER,
cpus INTEGER DEFAULT 0,
memory INTEGER DEFAULT 0,
disk INTEGER DEFAULT 0,
os TEXT,
kernel TEXT,
last_state_change DATETIME DEFAULT CURRENT_TIMESTAMP,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);`,
`CREATE TABLE IF NOT EXISTS appliances (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE,
description TEXT,
node_type TEXT DEFAULT 'compute'
);`,
`CREATE TABLE IF NOT EXISTS memberships (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE,
appliance_id INTEGER,
FOREIGN KEY (appliance_id) REFERENCES appliances(id) ON DELETE SET NULL
);`,
`CREATE TABLE IF NOT EXISTS subnets (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE,
network TEXT,
netmask TEXT,
gateway TEXT,
dns_zone TEXT,
is_private INTEGER DEFAULT 1
);`,
`CREATE TABLE IF NOT EXISTS networks (
id INTEGER PRIMARY KEY AUTOINCREMENT,
node_id INTEGER NOT NULL,
name TEXT,
ip TEXT UNIQUE,
mac TEXT UNIQUE,
subnet_id INTEGER,
interface TEXT DEFAULT 'eth0',
FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE CASCADE,
FOREIGN KEY (subnet_id) REFERENCES subnets(id) ON DELETE SET NULL
);`,
`CREATE TABLE IF NOT EXISTS aliases (
id INTEGER PRIMARY KEY AUTOINCREMENT,
node_id INTEGER NOT NULL,
name TEXT NOT NULL,
FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE CASCADE,
UNIQUE(node_id, name)
);`,
`CREATE TABLE IF NOT EXISTS categories (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE
);`,
`CREATE TABLE IF NOT EXISTS catindexes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
category_id INTEGER NOT NULL,
FOREIGN KEY (category_id) REFERENCES categories(id) ON DELETE CASCADE,
UNIQUE(name, category_id)
);`,
`CREATE TABLE IF NOT EXISTS attributes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
attr TEXT NOT NULL,
value TEXT,
category_id INTEGER NOT NULL,
catindex_id INTEGER NOT NULL,
FOREIGN KEY (category_id) REFERENCES categories(id) ON DELETE CASCADE,
FOREIGN KEY (catindex_id) REFERENCES catindexes(id) ON DELETE CASCADE,
UNIQUE(attr, category_id, catindex_id)
);`,
`CREATE TABLE IF NOT EXISTS node_attrs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
node_id INTEGER NOT NULL,
attr TEXT NOT NULL,
value TEXT,
FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE CASCADE,
UNIQUE(node_id, attr)
);`,
`CREATE TABLE IF NOT EXISTS hostselections (
id INTEGER PRIMARY KEY AUTOINCREMENT,
host_id INTEGER NOT NULL,
category_id INTEGER NOT NULL,
selection TEXT NOT NULL,
FOREIGN KEY (host_id) REFERENCES nodes(id) ON DELETE CASCADE,
FOREIGN KEY (category_id) REFERENCES categories(id) ON DELETE CASCADE,
UNIQUE(host_id, category_id, selection)
);`,
`CREATE TABLE IF NOT EXISTS resolvechain (
id INTEGER PRIMARY KEY AUTOINCREMENT,
category_id INTEGER NOT NULL,
precedence INTEGER NOT NULL,
FOREIGN KEY (category_id) REFERENCES categories(id) ON DELETE CASCADE,
UNIQUE(category_id, precedence)
);`,
`CREATE TABLE IF NOT EXISTS software_installs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
version TEXT,
install_type TEXT,
node_id INTEGER,
status TEXT,
installed_at DATETIME DEFAULT CURRENT_TIMESTAMP,
installed_by TEXT,
FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE SET NULL
);`,
// 创建索引
`CREATE INDEX IF NOT EXISTS idx_nodes_name ON nodes(name);`,
`CREATE INDEX IF NOT EXISTS idx_networks_ip ON networks(ip);`,
`CREATE INDEX IF NOT EXISTS idx_networks_mac ON networks(mac);`,
`CREATE INDEX IF NOT EXISTS idx_attributes_lookup ON attributes(attr, category_id, catindex_id);`,
`CREATE INDEX IF NOT EXISTS idx_node_attrs_lookup ON node_attrs(node_id, attr);`,
`CREATE INDEX IF NOT EXISTS idx_hostselections_host ON hostselections(host_id);`,
`CREATE INDEX IF NOT EXISTS idx_resolvechain_precedence ON resolvechain(precedence);`,
}
// 逐条执行SQL
for i, sql := range sqls {
if strings.TrimSpace(sql) == "" {
continue
}
log.Debugf("执行SQL[%d]: %s", i, strings.TrimSpace(strings.Split(sql, "\n")[0]))
_, err := tx.Exec(sql)
if err != nil {
tx.Rollback()
return fmt.Errorf("执行SQL[%d]失败: %v\nSQL: %s", i, err, sql)
}
}
// 提交事务
if err := tx.Commit(); err != nil {
return fmt.Errorf("提交事务失败: %v", err)
}
log.Info("数据库表结构创建成功")
// 插入默认数据
return db.insertDefaultData()
}
// insertDefaultData 插入默认数据
func (db *DB) insertDefaultData() error {
log.Debug("插入默认数据...")
// 默认类别
categories := []string{"global", "host", "os", "appliance", "network"}
for _, cat := range categories {
_, err := db.engine.Exec(
"INSERT OR IGNORE INTO categories (name) VALUES (?)",
cat,
)
if err != nil {
return err
}
}
log.Debug("插入默认类别索引...")
// 默认类别索引
catIndexes := []struct {
catName string
idxName string
}{
{"global", "global"},
{"os", "linux"},
{"network", "private"},
}
for _, ci := range catIndexes {
_, err := db.engine.Exec(`
INSERT OR IGNORE INTO catindexes (name, category_id)
SELECT ?, id FROM categories WHERE name = ?
`, ci.idxName, ci.catName)
if err != nil {
return err
}
}
log.Debug("插入默认解析链优先级...")
// 默认解析链优先级
precedence := []struct {
catName string
level int
}{
{"global", 1},
{"os", 2},
{"appliance", 3},
{"host", 4},
{"network", 5},
}
for _, p := range precedence {
_, err := db.engine.Exec(`
INSERT OR IGNORE INTO resolvechain (category_id, precedence)
SELECT id, ? FROM categories WHERE name = ?
`, p.level, p.catName)
if err != nil {
return err
}
}
log.Debug("插入默认设备类型...")
// 默认设备类型
appliances := []struct {
name string
desc string
typ string
}{
{"frontend", "管理节点", "master"},
{"compute", "计算节点", "compute"},
{"login", "登录节点", "login"},
{"storage", "存储节点", "storage"},
}
for _, a := range appliances {
_, err := db.engine.Exec(
"INSERT OR IGNORE INTO appliances (name, description, node_type) VALUES (?, ?, ?)",
a.name, a.desc, a.typ,
)
if err != nil {
return err
}
}
log.Debug("插入默认数据完成...")
return nil
}
// ==================== 核心查询方法 ====================
// Execute 执行SQL语句 - 对应Rocks的execute()
func (db *DB) Execute(query string, args ...interface{}) (int64, error) {
db.mu.RLock()
conn := db.conn
verbose := db.verbose
db.mu.RUnlock()
if conn == nil {
return 0, fmt.Errorf("没有活动数据库连接")
}
if verbose {
log.Debugf("执行SQL: %s %v", query, args)
}
// 判断SQL类型
upperQuery := strings.ToUpper(strings.TrimSpace(query))
isSelect := strings.HasPrefix(upperQuery, "SELECT")
if isSelect {
// SELECT 查询使用 QueryContext
rows, err := conn.QueryContext(context.Background(), query, args...)
if err != nil {
// 尝试重连一次
db.RenewConnection()
db.mu.RLock()
conn = db.conn
db.mu.RUnlock()
rows, err = conn.QueryContext(context.Background(), query, args...)
}
if err != nil {
return 0, err
}
// 关闭旧结果
db.mu.Lock()
if db.results != nil {
db.results.Close()
}
db.results = rows
db.mu.Unlock()
return 0, nil
} else {
// INSERT/UPDATE/DELETE 使用 Exec自动提交
result, err := conn.ExecContext(context.Background(), query, args...)
if err != nil {
// 尝试重连一次
db.RenewConnection()
db.mu.RLock()
conn = db.conn
db.mu.RUnlock()
result, err = conn.ExecContext(context.Background(), query, args...)
}
if err != nil {
return 0, err
}
// 获取影响行数
rowsAffected, err := result.RowsAffected()
if err != nil {
return 0, err
}
if verbose {
log.Debugf("影响行数: %d", rowsAffected)
}
return rowsAffected, nil
}
}
// FetchOne 获取一行 - 对应Rocks的fetchone()
// 返回map[string]interface{}格式key为列名
func (db *DB) FetchOne() (map[string]interface{}, error) {
db.mu.RLock()
results := db.results
db.mu.RUnlock()
if results == nil {
return nil, nil
}
if !results.Next() {
return nil, nil
}
columns, err := results.Columns()
if err != nil {
return nil, err
}
values := make([]interface{}, len(columns))
scanArgs := make([]interface{}, len(columns))
for i := range values {
scanArgs[i] = &values[i]
}
err = results.Scan(scanArgs...)
if err != nil {
return nil, err
}
row := make(map[string]interface{})
for i, col := range columns {
val := values[i]
if b, ok := val.([]byte); ok {
row[col] = string(b)
} else {
row[col] = val
}
}
return row, nil
}
// FetchAll 获取所有行 - 对应Rocks的fetchall()
// 返回[]map[string]interface{}格式
func (db *DB) FetchAll() ([]map[string]interface{}, error) {
db.mu.RLock()
results := db.results
db.mu.RUnlock()
if results == nil {
return nil, nil
}
columns, err := results.Columns()
if err != nil {
return nil, err
}
var rows []map[string]interface{}
for results.Next() {
values := make([]interface{}, len(columns))
scanArgs := make([]interface{}, len(columns))
for i := range values {
scanArgs[i] = &values[i]
}
err = results.Scan(scanArgs...)
if err != nil {
return nil, err
}
row := make(map[string]interface{})
for i, col := range columns {
val := values[i]
if b, ok := val.([]byte); ok {
row[col] = string(b)
} else {
row[col] = val
}
}
rows = append(rows, row)
}
return rows, nil
}
// ==================== 连接维护 ====================
// RenewConnection 续期连接
func (db *DB) RenewConnection() error {
db.mu.Lock()
defer db.mu.Unlock()
if db.conn != nil {
db.conn.Close()
}
conn, err := db.engine.Conn(context.Background())
if err != nil {
return err
}
db.conn = conn
return nil
}
// Close 关闭连接
func (db *DB) Close() error {
db.mu.Lock()
defer db.mu.Unlock()
if db.results != nil {
db.results.Close()
db.results = nil
}
if db.conn != nil {
db.conn.Close()
db.conn = nil
}
if db.engine != nil {
return db.engine.Close()
}
return nil
}
// CloseConnection 只关闭当前连接,不关闭连接池
func (db *DB) CloseConnection() error {
db.mu.Lock()
defer db.mu.Unlock()
if db.results != nil {
db.results.Close()
db.results = nil
}
if db.conn != nil {
db.conn.Close()
db.conn = nil
}
return nil
}
// ==================== 单例模式 ====================
var (
instanceConfigured bool
instanceDBPath string
instanceDBName string
)
func GetInstance() (*DB, error) {
return GetInstanceWithConfig("", "")
}
func GetInstanceWithConfig(dbPath, dbName string) (*DB, error) {
var err error
once.Do(func() {
globalDB = NewDB()
log.Debug("创建数据库实例...")
globalDB.SetDBUsername(globalDB.GetDBUsername())
if dbPath != "" {
globalDB.SetDBPath(dbPath)
log.Debugf("设置数据库路径: %s", dbPath)
}
if dbName != "" {
globalDB.SetDBName(dbName)
log.Debugf("设置数据库名称: %s", dbName)
}
instanceConfigured = (dbPath != "" || dbName != "")
if dbPath != "" {
instanceDBPath = dbPath
}
if dbName != "" {
instanceDBName = dbName
}
err = globalDB.Connect()
})
return globalDB, err
}
func IsInstanceConfigured() bool {
return instanceConfigured
}
func GetInstanceConfig() (dbPath, dbName string) {
return instanceDBPath, instanceDBName
}

624
internal/db/helper.go Normal file
View File

@@ -0,0 +1,624 @@
package db
import (
"fmt"
"net"
"os"
"strings"
"sync"
)
/*
// 获取数据库实例
database, err := db.GetInstance()
if err != nil {
log.Fatal(err)
}
defer database.Close()
// 创建Helper
helper, _ := db.NewDBHelper()
// 执行查询
helper.Execute("SELECT * FROM nodes WHERE rack = ?", 1)
// 获取一行
row, _ := helper.FetchOne()
if row != nil {
log.Infof("节点: %v", row["name"])
}
// 获取所有行
rows, _ := helper.FetchAll()
log.Infof("共 %d 个节点", len(rows))
// 使用Helper方法
hostname, _ := helper.GetHostname("192.168.1.1")
log.Infof("解析主机名: %s", hostname)
// 设置属性
helper.SetCategoryAttr("global", "global", "Kickstart_PrivateHostname", "sunhpc-master")
// 获取属性
val := helper.GetCategoryAttr("global", "global", "Kickstart_PrivateHostname")
log.Infof("前端主机名: %s", val)
*/
const (
attrPostfix = "_old"
)
// DBHelper DatabaseHelper类 - 继承DB扩展业务方法
type DBHelper struct {
*DB
appliancesList []string
frontendName string
cacheAttrs sync.Map
}
func NewDBHelper() (*DBHelper, error) {
db, err := GetInstance()
if err != nil {
return nil, err
}
return &DBHelper{
DB: db,
appliancesList: nil,
frontendName: "",
}, nil
}
// ==================== 节点查询 ====================
// GetListHostnames 获取所有主机名列表
func (h *DBHelper) GetListHostnames() ([]string, error) {
_, err := h.Execute("SELECT name FROM nodes ORDER BY name")
if err != nil {
return nil, err
}
rows, err := h.FetchAll()
if err != nil {
return nil, err
}
var names []string
for _, row := range rows {
if name, ok := row["name"]; ok {
names = append(names, name.(string))
}
}
return names, nil
}
// GetNodesFromNames 从名称列表获取节点
func (h *DBHelper) GetNodesFromNames(names []string, managedOnly bool) ([]map[string]interface{}, error) {
// 如果没有提供名称,返回所有节点
if len(names) == 0 {
query := "SELECT * FROM nodes"
if managedOnly {
query = `
SELECT n.* FROM nodes n
JOIN node_attrs a ON n.id = a.node_id
WHERE a.attr = 'managed' AND a.value = 'true'
`
}
_, err := h.Execute(query)
if err != nil {
return nil, err
}
return h.FetchAll()
}
// 构建查询条件
conditions := []string{}
args := []interface{}{}
for _, name := range names {
if strings.HasPrefix(name, "select ") {
conditions = append(conditions, fmt.Sprintf("name IN (%s)", name[7:]))
} else if strings.Contains(name, "%") {
conditions = append(conditions, "name LIKE ?")
args = append(args, name)
} else if strings.HasPrefix(name, "rack") {
rackNum := strings.TrimPrefix(name, "rack")
conditions = append(conditions, "rack = ?")
args = append(args, rackNum)
} else if h.IsApplianceName(name) {
conditions = append(conditions, `id IN (
SELECT node_id FROM node_attrs
WHERE attr = 'appliance' AND value = ?
)`)
args = append(args, name)
} else {
hostname, err := h.GetHostname(name)
if err == nil {
conditions = append(conditions, "name = ?")
args = append(args, hostname)
}
}
}
if len(conditions) == 0 {
return []map[string]interface{}{}, nil
}
query := "SELECT * FROM nodes WHERE " + strings.Join(conditions, " OR ")
_, err := h.Execute(query, args...)
if err != nil {
return nil, err
}
nodes, err := h.FetchAll()
if err != nil {
return nil, err
}
// 过滤受管节点
if managedOnly {
var managed []map[string]interface{}
for _, node := range nodes {
val := h.GetHostAttr(node["name"].(string), "managed")
if val == "true" {
managed = append(managed, node)
}
}
return managed, nil
}
return nodes, nil
}
// ==================== 设备类型 ====================
// GetAppliancesListText 获取所有设备类型名称
func (h *DBHelper) GetAppliancesListText() []string {
if h.appliancesList != nil {
return h.appliancesList
}
_, err := h.Execute("SELECT DISTINCT value FROM node_attrs WHERE attr = 'appliance'")
if err != nil {
return []string{}
}
rows, _ := h.FetchAll()
var apps []string
for _, row := range rows {
if val, ok := row["value"]; ok {
apps = append(apps, val.(string))
}
}
h.appliancesList = apps
return apps
}
// IsApplianceName 检查是否为设备类型名称
func (h *DBHelper) IsApplianceName(name string) bool {
for _, app := range h.GetAppliancesListText() {
if app == name {
return true
}
}
return false
}
// ==================== 主机名解析 ====================
// GetHostname 规范化主机名 - 完全参考Rocks实现
func (h *DBHelper) GetHostname(hostname string) (string, error) {
// 如果hostname为空使用系统主机名
if hostname == "" {
hostname, _ = os.Hostname()
hostname = strings.Split(hostname, ".")[0]
return h.GetHostname(hostname)
}
// 1. 直接在nodes表中查找
_, err := h.Execute("SELECT * FROM nodes WHERE name = ?", hostname)
if err == nil {
row, _ := h.FetchOne()
if row != nil {
return hostname, nil
}
}
// 2. 尝试IP地址反向解析
addr := net.ParseIP(hostname)
if addr != nil {
names, err := net.LookupAddr(hostname)
if err == nil && len(names) > 0 {
return h.GetHostname(strings.Split(names[0], ".")[0])
}
}
// 3. 在networks表中查找IP
if addr != nil {
_, err := h.Execute(`
SELECT n.name FROM nodes n
JOIN networks net ON n.id = net.node_id
WHERE net.ip = ?
`, addr.String())
if err == nil {
row, _ := h.FetchOne()
if row != nil {
return row["name"].(string), nil
}
}
}
// 4. 尝试MAC地址
mac := strings.ReplaceAll(hostname, "-", ":")
_, err = h.Execute(`
SELECT n.name FROM nodes n
JOIN networks net ON n.id = net.node_id
WHERE net.mac = ?
`, mac)
if err == nil {
row, _ := h.FetchOne()
if row != nil {
return row["name"].(string), nil
}
}
// 5. 检查别名
_, err = h.Execute(`
SELECT n.name FROM nodes n
JOIN aliases a ON n.id = a.node_id
WHERE a.name = ?
`, hostname)
if err == nil {
row, _ := h.FetchOne()
if row != nil {
return row["name"].(string), nil
}
}
// 6. 尝试FQDN
if strings.Contains(hostname, ".") {
parts := strings.Split(hostname, ".")
name := parts[0]
domain := strings.Join(parts[1:], ".")
_, err := h.Execute(`
SELECT n.name FROM nodes n
JOIN networks net ON n.id = net.node_id
JOIN subnets s ON net.subnet_id = s.id
WHERE s.dns_zone = ? AND (net.name = ? OR n.name = ?)
`, domain, name, name)
if err == nil {
row, _ := h.FetchOne()
if row != nil {
return row["name"].(string), nil
}
}
}
// 7. 如果以上都失败,抛出异常
return "", fmt.Errorf("无法解析主机名: %s", hostname)
}
// CheckHostnameValidity 检查主机名有效性
func (h *DBHelper) CheckHostnameValidity(hostname string) error {
// 不能包含点
if strings.Contains(hostname, ".") {
return fmt.Errorf("主机名 %s 不能包含点号", hostname)
}
// 不能是rack<数字>格式
if strings.HasPrefix(hostname, "rack") {
num := strings.TrimPrefix(hostname, "rack")
if _, err := fmt.Sscanf(num, "%d", new(int)); err == nil {
return fmt.Errorf("主机名 %s 不能是rack<数字>格式", hostname)
}
}
// 不能是设备类型名称
if h.IsApplianceName(hostname) {
return fmt.Errorf("主机名 %s 不能与设备类型名称相同", hostname)
}
// 检查是否已存在
_, err := h.GetHostname(hostname)
if err == nil {
return fmt.Errorf("节点 %s 已存在", hostname)
}
return nil
}
// ==================== 前端节点 ====================
// GetFrontendName 获取前端节点名称
func (h *DBHelper) GetFrontendName() string {
if h.frontendName != "" {
return h.frontendName
}
name := h.GetCategoryAttr("global", "global", "Kickstart_PrivateHostname")
if name != "" {
h.frontendName = name
}
return h.frontendName
}
// ==================== 属性管理 ====================
// GetCategoryIndex 获取类别索引
func (h *DBHelper) GetCategoryIndex(categoryName, categoryIndex string) (map[string]interface{}, map[string]interface{}, error) {
// 查询类别和索引
_, err := h.Execute(`
SELECT c.id as cid, c.name as cname, i.id as iid, i.name as iname
FROM categories c
JOIN catindexes i ON c.id = i.category_id
WHERE c.name = ? AND i.name = ?
`, categoryName, categoryIndex)
if err == nil {
row, _ := h.FetchOne()
if row != nil {
category := map[string]interface{}{
"id": row["cid"],
"name": row["cname"],
}
catindex := map[string]interface{}{
"id": row["iid"],
"name": row["iname"],
"category_id": row["cid"],
}
return category, catindex, nil
}
}
// 不存在则创建
// 创建类别
_, err = h.Execute("INSERT INTO categories (name) VALUES (?)", categoryName)
if err != nil {
return nil, nil, err
}
var catID int64
h.Execute("SELECT last_insert_rowid()")
row, _ := h.FetchOne()
if row != nil {
catID = row["last_insert_rowid()"].(int64)
}
// 创建索引
_, err = h.Execute(
"INSERT INTO catindexes (name, category_id) VALUES (?, ?)",
categoryIndex, catID,
)
if err != nil {
return nil, nil, err
}
h.Execute("SELECT last_insert_rowid()")
row, _ = h.FetchOne()
var idxID int64
if row != nil {
idxID = row["last_insert_rowid()"].(int64)
}
category := map[string]interface{}{
"id": catID,
"name": categoryName,
}
catindex := map[string]interface{}{
"id": idxID,
"name": categoryIndex,
"category_id": catID,
}
return category, catindex, nil
}
// SetCategoryAttr 设置类别属性
func (h *DBHelper) SetCategoryAttr(categoryName, catindexName, attr, value string) error {
cat, catindex, err := h.GetCategoryIndex(categoryName, catindexName)
if err != nil {
return err
}
// 查询现有属性
_, err = h.Execute(`
SELECT id, value FROM attributes
WHERE attr = ? AND category_id = ? AND catindex_id = ?
`, attr, cat["id"], catindex["id"])
if err == nil {
row, _ := h.FetchOne()
if row != nil {
// 更新现有属性
oldValue := row["value"]
attrID := row["id"]
_, err = h.Execute(
"UPDATE attributes SET value = ? WHERE id = ?",
value, attrID,
)
if err != nil {
return err
}
// 保存旧值
if !strings.HasSuffix(attr, attrPostfix) {
h.SetCategoryAttr(categoryName, catindexName, attr+attrPostfix, oldValue.(string))
}
return nil
}
}
// 创建新属性
_, err = h.Execute(`
INSERT INTO attributes (attr, value, category_id, catindex_id)
VALUES (?, ?, ?, ?)
`, attr, value, cat["id"], catindex["id"])
return err
}
// GetCategoryAttr 获取类别属性
func (h *DBHelper) GetCategoryAttr(categoryName, catindexName, attrName string) string {
cat, catindex, err := h.GetCategoryIndex(categoryName, catindexName)
if err != nil {
return ""
}
_, err = h.Execute(`
SELECT value FROM attributes
WHERE attr = ? AND category_id = ? AND catindex_id = ?
`, attrName, cat["id"], catindex["id"])
if err != nil {
return ""
}
row, _ := h.FetchOne()
if row == nil {
return ""
}
return row["value"].(string)
}
// RemoveCategoryAttr 移除类别属性
func (h *DBHelper) RemoveCategoryAttr(categoryName, catindexName, attrName string) error {
cat, catindex, err := h.GetCategoryIndex(categoryName, catindexName)
if err != nil {
return err
}
_, err = h.Execute(`
DELETE FROM attributes
WHERE attr = ? AND category_id = ? AND catindex_id = ?
`, attrName, cat["id"], catindex["id"])
if err != nil {
return err
}
// 同时删除对应的_old属性
_, _ = h.Execute(`
DELETE FROM attributes
WHERE attr = ? AND category_id = ? AND catindex_id = ?
`, attrName+attrPostfix, cat["id"], catindex["id"])
return nil
}
// ==================== 主机属性 ====================
// GetHostAttr 获取主机属性
func (h *DBHelper) GetHostAttr(hostname, attr string) string {
// 先从节点直接属性查询
_, err := h.Execute(`
SELECT value FROM node_attrs
WHERE node_id = (SELECT id FROM nodes WHERE name = ?)
AND attr = ?
`, hostname, attr)
if err == nil {
row, _ := h.FetchOne()
if row != nil {
return row["value"].(string)
}
}
// 使用Rocks的属性解析链查询
query := `
SELECT a.value FROM attributes a
JOIN resolvechain r ON a.category_id = r.category_id
JOIN hostselections h ON a.category_id = h.category_id
AND a.catindex_id = h.selection
WHERE h.host_id = (SELECT id FROM nodes WHERE name = ?)
AND a.attr = ?
ORDER BY r.precedence DESC
LIMIT 1
`
_, err = h.Execute(query, hostname, attr)
if err != nil {
return ""
}
row, _ := h.FetchOne()
if row == nil {
return ""
}
return row["value"].(string)
}
// GetHostAttrs 获取主机所有属性
func (h *DBHelper) GetHostAttrs(hostname string, showSource bool) map[string]interface{} {
attrs := make(map[string]interface{})
// 获取节点基本信息
_, err := h.Execute(`
SELECT n.id, n.name, n.rack, n.rank, m.name as membership, a.name as appliance
FROM nodes n
LEFT JOIN memberships m ON n.membership_id = m.id
LEFT JOIN appliances a ON m.appliance_id = a.id
WHERE n.name = ?
`, hostname)
if err == nil {
row, _ := h.FetchOne()
if row != nil {
if showSource {
attrs["hostname"] = []interface{}{row["name"], "I"}
attrs["rack"] = []interface{}{row["rack"], "I"}
attrs["rank"] = []interface{}{row["rank"], "I"}
attrs["appliance"] = []interface{}{row["appliance"], "I"}
attrs["membership"] = []interface{}{row["membership"], "I"}
} else {
attrs["hostname"] = row["name"]
attrs["rack"] = row["rack"]
attrs["rank"] = row["rank"]
attrs["appliance"] = row["appliance"]
attrs["membership"] = row["membership"]
}
}
}
// 获取所有属性
query := `
SELECT a.attr, a.value,
CASE
WHEN h.host_id IS NOT NULL THEN 'H'
ELSE UPPER(SUBSTR(c.name, 1, 1))
END as source
FROM attributes a
JOIN categories c ON a.category_id = c.id
LEFT JOIN hostselections h ON a.category_id = h.category_id
AND a.catindex_id = h.selection
AND h.host_id = (SELECT id FROM nodes WHERE name = ?)
UNION
SELECT attr, value, 'N' as source
FROM node_attrs
WHERE node_id = (SELECT id FROM nodes WHERE name = ?)
`
_, err = h.Execute(query, hostname, hostname)
if err == nil {
rows, _ := h.FetchAll()
for _, row := range rows {
attr := row["attr"].(string)
value := row["value"]
if showSource {
attrs[attr] = []interface{}{value, row["source"]}
} else {
attrs[attr] = value
}
}
}
return attrs
}

344
internal/log/logger.go Normal file
View File

@@ -0,0 +1,344 @@
package log
import (
"fmt"
"io"
"os"
"path/filepath"
"runtime"
"sync"
"time"
"github.com/fatih/color"
)
// 日志级别
type Level int
const (
DebugLevel Level = iota
InfoLevel
WarnLevel
ErrorLevel
FatalLevel
)
// 级别名称
var levelNames = map[Level]string{
DebugLevel: "DEBUG",
InfoLevel: "INFO",
WarnLevel: "WARN",
ErrorLevel: "ERROR",
FatalLevel: "FATAL",
}
// 级别简写
var levelShort = map[Level]string{
DebugLevel: "[d]",
InfoLevel: "[i]",
WarnLevel: "[w]",
ErrorLevel: "[e]",
FatalLevel: "[f]",
}
// 级别颜色
var levelColor = map[Level]func(format string, a ...interface{}) string{
DebugLevel: color.CyanString, // 青色
InfoLevel: color.GreenString, // 绿色
WarnLevel: color.YellowString, // 黄色
ErrorLevel: color.RedString, // 红色
FatalLevel: color.MagentaString, // 品红
}
// Logger 日志器结构体
type Logger struct {
mu sync.Mutex
consoleOut io.Writer // 控制台输出
fileOut io.Writer // 文件输出
minLevel Level // 最小输出级别
showColor bool // 是否显示颜色
showCaller bool // 是否显示调用者信息
callerSkip int // 调用者跳过的层级
timeFormat string // 时间格式
}
// 默认日志器实例
var defaultLogger *Logger
const (
defaultTimeFormat = "2006-01-02 15:04:05"
logFile = "/var/log/sunhpc/sunhpc.log"
)
// Init 初始化日志系统
func Init(verbose bool) {
// 确保日志目录存在
if err := os.MkdirAll(filepath.Dir(logFile), 0755); err != nil {
fmt.Fprintf(os.Stderr, "创建日志目录失败: %v\n", err)
os.Exit(1)
}
// 打开日志文件
file, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
fmt.Fprintf(os.Stderr, "打开日志文件失败: %v\n", err)
os.Exit(1)
}
// 控制台输出
consoleOut := os.Stdout
// 创建日志器
defaultLogger = &Logger{
consoleOut: consoleOut,
fileOut: file,
minLevel: InfoLevel,
showColor: true,
showCaller: false,
callerSkip: 2,
timeFormat: defaultTimeFormat,
}
// 详细模式下显示调试信息
if verbose {
defaultLogger.minLevel = DebugLevel
defaultLogger.showCaller = true
}
// 初始化颜色支持
if runtime.GOOS == "windows" {
color.NoColor = false
}
}
// log 核心日志输出方法
func (l *Logger) log(level Level, format string, args ...interface{}) {
if level < l.minLevel {
return
}
l.mu.Lock()
defer l.mu.Unlock()
// 生成时间戳
timestamp := time.Now().Format(l.timeFormat)
// 获取调用者信息
caller := ""
if l.showCaller {
_, file, line, ok := runtime.Caller(l.callerSkip)
if ok {
// 只保留文件名和行号
file = filepath.Base(file)
caller = fmt.Sprintf(" %s:%d", file, line)
}
}
// 格式化消息
var message string
if format == "" {
message = fmt.Sprint(args...)
} else {
message = fmt.Sprintf(format, args...)
}
// ---- 控制台输出(带颜色和简写)----
if l.consoleOut != nil {
// 获取级别简写
shortPrefix := levelShort[level]
// 构建控制台行
var consoleLine string
if l.showColor {
// 带颜色输出 - 简写有颜色,时间戳灰色
colorFunc := levelColor[level]
consoleLine = fmt.Sprintf("%s %s %s",
color.HiBlackString(timestamp), // 时间戳灰色
colorFunc(shortPrefix), // 级别简写彩色
message) // 消息普通颜色
} else {
// 不带颜色输出
consoleLine = fmt.Sprintf("%s %s %s",
timestamp,
shortPrefix,
message)
}
// 添加调用者信息(灰色)
if caller != "" {
if l.showColor {
consoleLine += fmt.Sprintf(" %s", color.HiBlackString(caller))
} else {
consoleLine += fmt.Sprintf(" %s", caller)
}
}
fmt.Fprintln(l.consoleOut, consoleLine)
}
// ---- 文件输出(完整格式)----
if l.fileOut != nil {
// 获取级别全名
levelName := levelNames[level]
// 文件使用完整格式:时间 [级别] 消息 调用者
fileLine := fmt.Sprintf("%s [%s] %s%s\n",
timestamp,
levelName,
message,
caller)
fmt.Fprint(l.fileOut, fileLine)
}
// 致命错误退出程序
if level == FatalLevel {
os.Exit(1)
}
}
// 全局日志函数
// Debug 调试日志
func Debug(args ...interface{}) {
if defaultLogger != nil {
defaultLogger.log(DebugLevel, "", args...)
}
}
// Debugf 格式化调试日志
func Debugf(format string, args ...interface{}) {
if defaultLogger != nil {
defaultLogger.log(DebugLevel, format, args...)
}
}
// Info 信息日志
func Info(args ...interface{}) {
if defaultLogger != nil {
defaultLogger.log(InfoLevel, "", args...)
}
}
// Infof 格式化信息日志
func Infof(format string, args ...interface{}) {
if defaultLogger != nil {
defaultLogger.log(InfoLevel, format, args...)
}
}
// Warn 警告日志
func Warn(args ...interface{}) {
if defaultLogger != nil {
defaultLogger.log(WarnLevel, "", args...)
}
}
// Warnf 格式化警告日志
func Warnf(format string, args ...interface{}) {
if defaultLogger != nil {
defaultLogger.log(WarnLevel, format, args...)
}
}
// Error 错误日志
func Error(args ...interface{}) {
if defaultLogger != nil {
defaultLogger.log(ErrorLevel, "", args...)
}
}
// Errorf 格式化错误日志
func Errorf(format string, args ...interface{}) {
if defaultLogger != nil {
defaultLogger.log(ErrorLevel, format, args...)
}
}
// Fatal 致命错误日志,输出后退出程序
func Fatal(args ...interface{}) {
if defaultLogger != nil {
defaultLogger.log(FatalLevel, "", args...)
}
}
// Fatalf 格式化致命错误日志,输出后退出程序
func Fatalf(format string, args ...interface{}) {
if defaultLogger != nil {
defaultLogger.log(FatalLevel, format, args...)
}
}
// Writer 返回一个 io.Writer可将子命令的输出写入日志Debug级别
func Writer() *io.PipeWriter {
r, w := io.Pipe()
go func() {
buf := make([]byte, 1024)
for {
n, err := r.Read(buf)
if n > 0 {
Debug(string(buf[:n]))
}
if err != nil {
break
}
}
}()
return w
}
// SetLevel 设置日志级别
func SetLevel(level Level) {
if defaultLogger != nil {
defaultLogger.mu.Lock()
defer defaultLogger.mu.Unlock()
defaultLogger.minLevel = level
}
}
// EnableColor 启用/禁用颜色输出
func EnableColor(enable bool) {
if defaultLogger != nil {
defaultLogger.mu.Lock()
defer defaultLogger.mu.Unlock()
defaultLogger.showColor = enable
}
}
// EnableCaller 启用/禁用调用者信息
func EnableCaller(enable bool) {
if defaultLogger != nil {
defaultLogger.mu.Lock()
defer defaultLogger.mu.Unlock()
defaultLogger.showCaller = enable
}
}
// SetTimeFormat 设置时间格式
func SetTimeFormat(format string) {
if defaultLogger != nil {
defaultLogger.mu.Lock()
defer defaultLogger.mu.Unlock()
defaultLogger.timeFormat = format
}
}
// Sync 同步日志文件
func Sync() {
if defaultLogger != nil && defaultLogger.fileOut != nil {
if f, ok := defaultLogger.fileOut.(*os.File); ok {
f.Sync()
}
}
}
// Close 关闭日志文件
func Close() error {
if defaultLogger != nil && defaultLogger.fileOut != nil {
if f, ok := defaultLogger.fileOut.(*os.File); ok {
return f.Close()
}
}
return nil
}

View File

@@ -0,0 +1,25 @@
package service
import (
"fmt"
"sunhpc/internal/config"
"sunhpc/internal/log"
"sunhpc/internal/template"
)
func Deploy(cfg *config.ServicesConfig) error {
// 示例:使用模板部署 DHCPD
if cfg.DHCPD.Enabled {
log.Info("部署 DHCPD 服务...")
// 从模板渲染配置文件
err := template.RenderAndExecute("dhcpd.conf.tmpl", map[string]interface{}{
"Subnet": "192.168.1.0",
"Netmask": "255.255.255.0",
})
if err != nil {
return fmt.Errorf("DHCPD 配置失败: %v", err)
}
// 实际部署逻辑(启动服务等)...
}
return nil
}

33
internal/soft/binary.go Normal file
View File

@@ -0,0 +1,33 @@
package soft
import (
"fmt"
"os"
"os/exec"
"strings"
)
// extractBinary 解压二进制压缩包到目标目录
func extractBinary(binPath, destDir string) error {
// 确保目标目录存在
if err := os.MkdirAll(destDir, 0755); err != nil {
return err
}
// 根据扩展名选择解压命令
var cmd *exec.Cmd
switch {
case strings.HasSuffix(binPath, ".tar.gz"), strings.HasSuffix(binPath, ".tgz"):
cmd = exec.Command("tar", "xzf", binPath, "-C", destDir)
case strings.HasSuffix(binPath, ".tar.bz2"):
cmd = exec.Command("tar", "xjf", binPath, "-C", destDir)
case strings.HasSuffix(binPath, ".zip"):
cmd = exec.Command("unzip", binPath, "-d", destDir)
default:
return fmt.Errorf("不支持的压缩格式: %s", binPath)
}
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}

49
internal/soft/manager.go Normal file
View File

@@ -0,0 +1,49 @@
package soft
import (
"sunhpc/internal/log"
)
// InstallContext 安装上下文,包含所有命令行参数
type InstallContext struct {
Force bool // 强制安装
DryRun bool // 干运行模式
KeepSource bool // 保留源码文件
Jobs int // 编译线程数
Offline bool // 离线模式
}
// InstallFromSource 从源码编译安装
func InstallFromSource(name, srcPath, prefix, version string, ctx *InstallContext) error {
log.Infof("正在从源码安装 %s路径: %s", name, srcPath)
if ctx != nil && ctx.DryRun {
log.Infof("[干运行] 将执行: configure && make -j%d && make install", ctx.Jobs)
return nil
}
// TODO: 实现具体逻辑:下载、解压、./configure、make、make install
log.Info("源码安装模拟完成(需实现具体步骤)")
return nil
}
// InstallFromBinary 从二进制压缩包安装
func InstallFromBinary(name, binPath, prefix string, ctx *InstallContext) error {
log.Infof("正在安装二进制包 %s路径: %s", name, binPath)
if ctx != nil && ctx.DryRun {
log.Infof("[干运行] 将解压 %s 到 %s", binPath, prefix)
return nil
}
// TODO: 解压到 prefix
log.Info("二进制安装模拟完成(需实现具体步骤)")
return nil
}
// InstallFromPackage 通过系统包管理器安装
func InstallFromPackage(name, pkgType string, ctx *InstallContext) error {
log.Infof("正在通过包管理器安装 %s (%s)", name, pkgType)
if ctx != nil && ctx.DryRun {
log.Infof("[干运行] 将执行包管理器安装 %s", name)
return nil
}
// 具体实现在下面的 package.go 中
return installViaPackageManager(name, pkgType)
}

38
internal/soft/package.go Normal file
View File

@@ -0,0 +1,38 @@
package soft
import (
"fmt"
"os"
"os/exec"
"sunhpc/internal/log"
"sunhpc/pkg/utils"
)
// installViaPackageManager 使用系统包管理器安装软件
func installViaPackageManager(name, pkgType string) error {
var cmd *exec.Cmd
switch pkgType {
case "rpm":
// RHEL/CentOS
if utils.CommandExists("yum") {
cmd = exec.Command("yum", "install", "-y", name)
} else if utils.CommandExists("dnf") {
cmd = exec.Command("dnf", "install", "-y", name)
} else {
return fmt.Errorf("未找到 yum 或 dnf 包管理器")
}
case "deb":
// Debian/Ubuntu
if !utils.CommandExists("apt-get") {
return fmt.Errorf("未找到 apt-get 包管理器")
}
cmd = exec.Command("apt-get", "install", "-y", name)
default:
return fmt.Errorf("不支持的包类型: %s", pkgType)
}
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
log.Infof("执行命令: %s", cmd.String())
return cmd.Run()
}

48
internal/soft/source.go Normal file
View File

@@ -0,0 +1,48 @@
package soft
import (
"fmt"
"os"
"os/exec"
"sunhpc/internal/log"
"sunhpc/pkg/utils"
)
// compileFromSource 通用源码编译流程
func compileFromSource(srcDir, prefix string, jobs int) error {
// 切换到源码目录
if err := os.Chdir(srcDir); err != nil {
return fmt.Errorf("进入源码目录失败: %v", err)
}
// 检测 configure 脚本是否存在
if utils.FileExists("./configure") {
log.Debug("执行 configure ...")
cmd := exec.Command("./configure", "--prefix="+prefix)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return fmt.Errorf("configure 失败: %v", err)
}
}
// make
log.Debugf("执行 make -j%d ...", jobs)
cmd := exec.Command("make", fmt.Sprintf("-j%d", jobs))
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return fmt.Errorf("make 失败: %v", err)
}
// make install
log.Debug("执行 make install ...")
cmd = exec.Command("make", "install")
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return fmt.Errorf("make install 失败: %v", err)
}
return nil
}

View File

@@ -0,0 +1,97 @@
package system
import (
"fmt"
"os"
"os/exec"
"strings"
)
// SetHostname 设置系统主机名
// 参数: hostname - 目标主机名
// 返回: error - 如果设置失败返回错误信息
func SetHostname(hostname string) error {
if hostname == "" {
return nil // 空值跳过,不报错
}
// 检查是否已有相同主机名
current, err := os.Hostname()
if err == nil && current == hostname {
return nil // 已经设置正确,无需修改
}
// 使用 hostnamectl 设置主机名(适用于 systemd 系统)
if _, err := exec.LookPath("hostnamectl"); err == nil {
cmd := exec.Command("hostnamectl", "set-hostname", hostname)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return fmt.Errorf("hostnamectl 设置主机名失败: %v", err)
}
} else {
// 传统方法:直接修改 /etc/hostname
if err := os.WriteFile("/etc/hostname", []byte(hostname+"\n"), 0644); err != nil {
return fmt.Errorf("写入 /etc/hostname 失败: %v", err)
}
// 立即生效(需要内核支持)
cmd := exec.Command("sysctl", "kernel.hostname="+hostname)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
// 不返回错误,因为重启后会生效
fmt.Printf("警告: 主机名将在重启后完全生效\n")
}
}
// 更新 /etc/hosts确保本机解析正确
if err := updateHostsFile(hostname); err != nil {
fmt.Printf("警告: 更新 /etc/hosts 失败: %v\n", err)
}
return nil
}
// updateHostsFile 更新 /etc/hosts 文件中的本机映射
func updateHostsFile(hostname string) error {
content, err := os.ReadFile("/etc/hosts")
if err != nil {
return err
}
lines := strings.Split(string(content), "\n")
newLines := make([]string, 0, len(lines))
hostnameSet := false
for _, line := range lines {
// 跳过空行和注释
if line == "" || strings.HasPrefix(line, "#") {
newLines = append(newLines, line)
continue
}
fields := strings.Fields(line)
if len(fields) >= 2 && fields[0] == "127.0.1.1" {
// 替换 Ubuntu/Debian 风格的本地主机名
newLines = append(newLines, "127.0.1.1\t"+hostname)
hostnameSet = true
} else if len(fields) >= 2 && fields[0] == "127.0.0.1" {
// 保留原行,但检查是否包含主机名
if !strings.Contains(line, hostname) {
line = line + " " + hostname
}
newLines = append(newLines, line)
hostnameSet = true
} else {
newLines = append(newLines, line)
}
}
// 如果没有找到合适的位置,添加一行
if !hostnameSet {
newLines = append(newLines, "127.0.1.1\t"+hostname)
}
return os.WriteFile("/etc/hosts", []byte(strings.Join(newLines, "\n")), 0644)
}

47
internal/system/motd.go Normal file
View File

@@ -0,0 +1,47 @@
package system
import (
"os"
"time"
)
// SetMOTD 设置 /etc/motd 文件内容
// 参数: content - MOTD 文本内容
// 返回: error - 写入文件错误
func SetMOTD(content string) error {
if content == "" {
// 如果内容为空,不清除现有 MOTD避免误操作
return nil
}
// 添加时间和系统信息
finalContent := "========================================\n"
finalContent += "SunHPC 集群管理系统\n"
finalContent += "时间: " + time.Now().Format("2006-01-02 15:04:05") + "\n"
finalContent += "========================================\n\n"
finalContent += content
// 确保行尾有换行
if content[len(content)-1] != '\n' {
finalContent += "\n"
}
return os.WriteFile("/etc/motd", []byte(finalContent), 0644)
}
// ClearMOTD 清空 MOTD
func ClearMOTD() error {
return os.WriteFile("/etc/motd", []byte{}, 0644)
}
// AppendToMOTD 追加内容到 MOTD
func AppendToMOTD(additional string) error {
f, err := os.OpenFile("/etc/motd", os.O_APPEND|os.O_WRONLY, 0644)
if err != nil {
return err
}
defer f.Close()
_, err = f.WriteString(additional + "\n")
return err
}

View File

@@ -0,0 +1,97 @@
package system
import (
"fmt"
"os"
"os/exec"
"strings"
)
// ConfigureSELinux 设置 SELinux 模式
// 参数: mode - enforcing, permissive, disabled
// 返回: error - 配置错误
func ConfigureSELinux(mode string) error {
if mode == "" {
return nil
}
// 验证输入
mode = strings.ToLower(strings.TrimSpace(mode))
validModes := map[string]bool{
"enforcing": true,
"permissive": true,
"disabled": true,
}
if !validModes[mode] {
return fmt.Errorf("无效的 SELinux 模式: %s (可选: enforcing, permissive, disabled)", mode)
}
// 检查 SELinux 是否可用
if _, err := os.Stat("/selinux/enforce"); os.IsNotExist(err) {
if _, err := os.Stat("/sys/fs/selinux/enforce"); os.IsNotExist(err) {
return fmt.Errorf("系统不支持 SELinux 或未启用")
}
}
// 临时生效
if mode != "disabled" { // disabled 需要重启才能完全生效
cmd := exec.Command("setenforce", mode)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return fmt.Errorf("setenforce %s 失败: %v", mode, err)
}
}
// 持久化配置
return persistSELinuxMode(mode)
}
// persistSELinuxMode 修改 /etc/selinux/config 实现持久化
func persistSELinuxMode(mode string) error {
const selinuxConfig = "/etc/selinux/config"
// 读取配置文件
content, err := os.ReadFile(selinuxConfig)
if err != nil {
// 如果文件不存在,创建默认配置
if os.IsNotExist(err) {
defaultConfig := fmt.Sprintf(`# This file controls the state of SELinux on the system.
# SELINUX= can take one of these three values:
# enforcing - SELinux security policy is enforced.
# permissive - SELinux prints warnings instead of enforcing.
# disabled - No SELinux policy is loaded.
SELINUX=%s
# SELINUXTYPE= can take one of three two values:
# targeted - Targeted processes are protected,
# minimum - Modification of targeted policy. Only selected processes are protected.
# mls - Multi Level Security protection.
SELINUXTYPE=targeted
`, mode)
return os.WriteFile(selinuxConfig, []byte(defaultConfig), 0644)
}
return err
}
// 替换 SELINUX= 行
lines := strings.Split(string(content), "\n")
for i, line := range lines {
if strings.HasPrefix(line, "SELINUX=") {
lines[i] = fmt.Sprintf("SELINUX=%s", mode)
break
}
}
return os.WriteFile(selinuxConfig, []byte(strings.Join(lines, "\n")), 0644)
}
// GetSELinuxMode 获取当前 SELinux 模式
func GetSELinuxMode() (string, error) {
cmd := exec.Command("getenforce")
output, err := cmd.Output()
if err != nil {
return "", err
}
return strings.ToLower(strings.TrimSpace(string(output))), nil
}

187
internal/system/ssh.go Normal file
View File

@@ -0,0 +1,187 @@
package system
import (
"fmt"
"os"
"os/exec"
"strings"
"sunhpc/internal/config"
)
// ConfigureSSH 配置 SSH 服务
// 参数: cfg - config.SSHConfig 结构体
// 返回: error - 配置错误
func ConfigureSSH(cfg config.SSHConfig) error {
const sshdConfig = "/etc/ssh/sshd_config"
// 读取现有配置
content, err := os.ReadFile(sshdConfig)
if err != nil {
return fmt.Errorf("读取 sshd_config 失败: %v", err)
}
// 备份原始配置
backupPath := sshdConfig + ".sunhpc.bak"
if _, err := os.Stat(backupPath); os.IsNotExist(err) {
if err := os.WriteFile(backupPath, content, 0644); err != nil {
fmt.Printf("警告: 无法创建备份文件 %s: %v\n", backupPath, err)
}
}
// 解析和修改配置
lines := strings.Split(string(content), "\n")
newLines := make([]string, 0, len(lines))
configMap := make(map[string]bool)
// 处理每一行
for _, line := range lines {
trimmed := strings.TrimSpace(line)
if trimmed == "" || strings.HasPrefix(trimmed, "#") {
newLines = append(newLines, line)
continue
}
parts := strings.Fields(trimmed)
if len(parts) < 2 {
newLines = append(newLines, line)
continue
}
key := parts[0]
configMap[key] = true
// 根据配置更新
switch key {
case "PermitRootLogin":
if cfg.PermitRootLogin != "" {
newLines = append(newLines, fmt.Sprintf("PermitRootLogin %s", cfg.PermitRootLogin))
} else {
newLines = append(newLines, line)
}
case "PasswordAuthentication":
if cfg.PasswordAuth != "" {
newLines = append(newLines, fmt.Sprintf("PasswordAuthentication %s", cfg.PasswordAuth))
} else {
newLines = append(newLines, line)
}
default:
newLines = append(newLines, line)
}
}
// 添加缺失的配置项
if cfg.PermitRootLogin != "" && !configMap["PermitRootLogin"] {
newLines = append(newLines, fmt.Sprintf("PermitRootLogin %s", cfg.PermitRootLogin))
}
if cfg.PasswordAuth != "" && !configMap["PasswordAuthentication"] {
newLines = append(newLines, fmt.Sprintf("PasswordAuthentication %s", cfg.PasswordAuth))
}
// 写入新配置
newContent := strings.Join(newLines, "\n")
if err := os.WriteFile(sshdConfig, []byte(newContent), 0644); err != nil {
return fmt.Errorf("写入 sshd_config 失败: %v", err)
}
// 测试配置语法
if err := testSSHDConfig(); err != nil {
// 恢复备份
if backup, err := os.ReadFile(backupPath); err == nil {
os.WriteFile(sshdConfig, backup, 0644)
}
return fmt.Errorf("SSH 配置语法错误: %v", err)
}
// 重启 SSH 服务
return restartSSHD()
}
// testSSHDConfig 测试 sshd 配置语法
func testSSHDConfig() error {
cmd := exec.Command("sshd", "-t")
cmd.Stderr = os.Stderr
return cmd.Run()
}
// restartSSHD 重启 SSH 服务
func restartSSHD() error {
// 尝试不同的服务管理器
serviceMgrs := []struct {
name string
args []string
}{
{"systemctl", []string{"restart", "sshd"}},
{"systemctl", []string{"restart", "ssh"}},
{"service", []string{"sshd", "restart"}},
{"service", []string{"ssh", "restart"}},
}
for _, mgr := range serviceMgrs {
if _, err := exec.LookPath(mgr.name); err == nil {
cmd := exec.Command(mgr.name, mgr.args...)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err == nil {
return nil
}
}
}
return fmt.Errorf("无法重启 SSH 服务,请手动重启")
}
// AddSSHKey 添加 SSH 公钥到指定用户
func AddSSHKey(username, pubkey string) error {
// 获取用户主目录
homeDir, err := getUserHomeDir(username)
if err != nil {
return err
}
sshDir := homeDir + "/.ssh"
authKeys := sshDir + "/authorized_keys"
// 创建 .ssh 目录
if err := os.MkdirAll(sshDir, 0700); err != nil {
return err
}
// 追加公钥
f, err := os.OpenFile(authKeys, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
return err
}
defer f.Close()
_, err = f.WriteString(pubkey + "\n")
if err != nil {
return err
}
// 修改所有权
return chownRecursive(sshDir, username)
}
// getUserHomeDir 获取用户主目录
func getUserHomeDir(username string) (string, error) {
cmd := exec.Command("getent", "passwd", username)
output, err := cmd.Output()
if err != nil {
return "", fmt.Errorf("用户 %s 不存在", username)
}
parts := strings.Split(strings.TrimSpace(string(output)), ":")
if len(parts) >= 6 {
return parts[5], nil
}
return "", fmt.Errorf("无法获取用户 %s 的主目录", username)
}
// chownRecursive 递归修改文件所有者
func chownRecursive(path, username string) error {
cmd := exec.Command("chown", "-R", username+":"+username, path)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}

87
internal/system/sysctl.go Normal file
View File

@@ -0,0 +1,87 @@
package system
import (
"bufio"
"fmt"
"os"
"os/exec"
"strings"
"time"
)
// ConfigureSysctl 设置内核参数
// 参数: params - 键值对映射,如 {"net.ipv4.ip_forward": "1"}
// 返回: error - 第一个失败的错误
func ConfigureSysctl(params map[string]string) error {
if len(params) == 0 {
return nil
}
// 首先应用临时配置
for k, v := range params {
cmd := exec.Command("sysctl", "-w", fmt.Sprintf("%s=%s", k, v))
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return fmt.Errorf("设置 sysctl %s=%s 失败: %v", k, v, err)
}
}
// 持久化配置到 /etc/sysctl.conf 或 /etc/sysctl.d/
return appendToSysctlConf(params)
}
// appendToSysctlConf 将参数写入持久化配置文件
func appendToSysctlConf(params map[string]string) error {
const sysctlFile = "/etc/sysctl.d/99-sunhpc.conf"
// 读取现有配置
existing := make(map[string]bool)
if data, err := os.ReadFile(sysctlFile); err == nil {
scanner := bufio.NewScanner(strings.NewReader(string(data)))
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
parts := strings.SplitN(line, "=", 2)
if len(parts) == 2 {
existing[strings.TrimSpace(parts[0])] = true
}
}
}
// 构建新内容
var content strings.Builder
content.WriteString("# SunHPC 系统优化配置\n")
content.WriteString("# 生成时间: " + time.Now().Format(time.RFC3339) + "\n\n")
for k, v := range params {
// 跳过已存在的配置(避免重复)
if existing[k] {
continue
}
content.WriteString(fmt.Sprintf("%s = %s\n", k, v))
}
// 如果有新配置才写入
if content.Len() > 100 {
if err := os.WriteFile(sysctlFile, []byte(content.String()), 0644); err != nil {
return err
}
// 应用持久化配置
return exec.Command("sysctl", "--system").Run()
}
return nil
}
// GetSysctl 获取当前内核参数值
func GetSysctl(key string) (string, error) {
cmd := exec.Command("sysctl", "-n", key)
output, err := cmd.Output()
if err != nil {
return "", err
}
return strings.TrimSpace(string(output)), nil
}

178
internal/system/system.go Normal file
View File

@@ -0,0 +1,178 @@
package system
import (
"os"
"path/filepath"
"sunhpc/internal/config"
"sunhpc/internal/log"
)
// Context 系统配置上下文,包含所有命令行参数
type Context struct {
Force bool // 强制模式
DryRun bool // 干运行模式
Verbose bool // 详细输出
Timeout int // 超时时间
Backup string // 备份路径
YesMode bool // 自动确认
}
// ApplyAll 应用所有系统配置
func ApplyAll(cfg *config.SunHPCConfig) error {
log.Info("开始应用系统配置...")
if err := SetHostnameWithContext(cfg.Hostname, nil); err != nil {
log.Warnf("设置主机名失败: %v", err)
}
if err := SetMOTDWithContext(cfg.MOTD, nil); err != nil {
log.Warnf("设置 MOTD 失败: %v", err)
}
if err := ConfigureSysctlWithContext(cfg.Sysctl, nil); err != nil {
log.Warnf("配置 sysctl 失败: %v", err)
}
if err := ConfigureSELinuxWithContext(cfg.SELinux, nil); err != nil {
log.Warnf("配置 SELinux 失败: %v", err)
}
if err := ConfigureSSHWithContext(cfg.SSH, nil); err != nil {
log.Warnf("配置 SSH 失败: %v", err)
}
log.Info("系统配置应用完成")
return nil
}
// SetHostnameWithContext 设置系统主机名,带上下文参数
func SetHostnameWithContext(hostname string, ctx *Context) error {
if ctx != nil && ctx.DryRun {
log.Infof("[干运行] 设置主机名为: %s", hostname)
return nil
}
if hostname == "" {
return nil
}
// 检查是否需要强制设置
current, _ := os.Hostname()
if current == hostname && (ctx == nil || !ctx.Force) {
log.Infof("主机名已是 '%s',跳过设置", hostname)
return nil
}
log.Infof("设置主机名为: %s", hostname)
return SetHostname(hostname)
}
// SetMOTDWithContext 设置 MOTD带上下文参数
func SetMOTDWithContext(content string, ctx *Context) error {
if ctx != nil && ctx.DryRun {
log.Info("[干运行] 设置 MOTD")
return nil
}
if content == "" {
return nil
}
// 备份现有文件
if ctx != nil && ctx.Backup != "" {
backupMOTD(ctx.Backup)
}
log.Info("更新 /etc/motd")
return SetMOTD(content)
}
// ConfigureSysctlWithContext 配置内核参数,带上下文参数
func ConfigureSysctlWithContext(params map[string]string, ctx *Context) error {
if ctx != nil && ctx.DryRun {
log.Info("[干运行] 配置 sysctl 参数")
return nil
}
if len(params) == 0 {
return nil
}
// 备份现有配置
if ctx != nil && ctx.Backup != "" {
backupSysctl(ctx.Backup)
}
return ConfigureSysctl(params)
}
// ConfigureSELinuxWithContext 配置 SELinux带上下文参数
func ConfigureSELinuxWithContext(mode string, ctx *Context) error {
if ctx != nil && ctx.DryRun {
log.Infof("[干运行] 设置 SELinux 模式为: %s", mode)
return nil
}
if mode == "" {
return nil
}
// 检查当前模式
current, _ := GetSELinuxMode()
if current == mode && (ctx == nil || !ctx.Force) {
log.Infof("SELinux 已是 '%s' 模式,跳过设置", mode)
return nil
}
log.Infof("设置 SELinux 模式为: %s", mode)
return ConfigureSELinux(mode)
}
// ConfigureSSHWithContext 配置 SSH带上下文参数
func ConfigureSSHWithContext(cfg config.SSHConfig, ctx *Context) error {
if ctx != nil && ctx.DryRun {
log.Info("[干运行] 配置 SSH 服务")
return nil
}
// 备份配置文件
if ctx != nil && ctx.Backup != "" {
backupSSHConfig(ctx.Backup)
}
log.Info("配置 SSH 服务")
return ConfigureSSH(cfg)
}
// 备份函数
func backupMOTD(backupDir string) error {
backupPath := filepath.Join(backupDir, "motd."+filepath.Base(os.Args[0])+".bak")
if err := os.MkdirAll(backupDir, 0755); err != nil {
return err
}
return copyFile("/etc/motd", backupPath)
}
func backupSysctl(backupDir string) error {
backupPath := filepath.Join(backupDir, "sysctl.conf.bak")
if err := os.MkdirAll(backupDir, 0755); err != nil {
return err
}
return copyFile("/etc/sysctl.conf", backupPath)
}
func backupSSHConfig(backupDir string) error {
backupPath := filepath.Join(backupDir, "sshd_config.bak")
if err := os.MkdirAll(backupDir, 0755); err != nil {
return err
}
return copyFile("/etc/ssh/sshd_config", backupPath)
}
func copyFile(src, dst string) error {
data, err := os.ReadFile(src)
if err != nil {
return err
}
return os.WriteFile(dst, data, 0644)
}

View File

@@ -0,0 +1,61 @@
package template
import (
"bytes"
"fmt"
"os"
"os/exec"
"path/filepath"
"text/template"
"sunhpc/internal/config"
"sunhpc/internal/log"
)
// RenderAndExecute 从模板目录加载模板,渲染后生成临时脚本并执行
// tmplName: 模板文件名(位于 /etc/sunhpc/tmpl.d/
// data: 模板变量
func RenderAndExecute(tmplName string, data interface{}) error {
tmplPath := filepath.Join(config.TmplDir, tmplName)
if _, err := os.Stat(tmplPath); err != nil {
return fmt.Errorf("模板文件不存在: %s", tmplPath)
}
content, err := os.ReadFile(tmplPath)
if err != nil {
return err
}
t, err := template.New(tmplName).Parse(string(content))
if err != nil {
return err
}
var buf bytes.Buffer
if err := t.Execute(&buf, data); err != nil {
return err
}
// 生成临时脚本
tmpFile, err := os.CreateTemp("/tmp", "sunhpc-*.sh")
if err != nil {
return err
}
defer os.Remove(tmpFile.Name())
if _, err := tmpFile.Write(buf.Bytes()); err != nil {
tmpFile.Close()
return err
}
tmpFile.Close()
if err := os.Chmod(tmpFile.Name(), 0755); err != nil {
return err
}
log.Infof("执行模板脚本: %s", tmpFile.Name())
cmd := exec.Command("/bin/bash", tmpFile.Name())
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}