package models import ( "database/sql/driver" "encoding/json" "fmt" "io/ioutil" "log" "main/configs" "net/http" "path/filepath" "time" "gopkg.in/yaml.v2" "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common" "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/errors" "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile" cvm "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/cvm/v20170312" "github.com/zhshch2002/goreq" ) type ModelList []int func (list *ModelList) Scan(value interface{}) error { return json.Unmarshal(value.([]byte), list) } func (list ModelList) Value() (driver.Value, error) { return json.Marshal(list) } type Server struct { ID string `json:"id" gorm:"primary_key"` Name string `json:"name"` Type string `json:"type"` // (训练|推理) IP string `json:"ip"` // 服务器IP Port int `json:"port"` // 7860 Status string `json:"status"` // (異常|初始化|閒置|就緒|工作中|關閉中) UserName string `json:"username"` // 用户名 Password string `json:"password"` // 用户密码 Models ModelList `json:"models"` // 服务器中所有模型 ModelID int `json:"model_id"` // 当前加载的模型 CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` } var config = struct { TencentCloud struct { SecretId string `yaml:"SecretId"` SecretKey string `yaml:"SecretKey"` Region string `yaml:"Region"` } `yaml:"TencentCloud"` }{} func init() { configs.ORMDB().AutoMigrate(&Server{}) // 檢查所有服務器的狀態, 無效的服務器設置為異常 var servers []Server configs.ORMDB().Find(&servers) for _, server := range servers { server.CheckStatus() } // 讀取配置文件 absPath, _ := filepath.Abs("./data/config.yaml") configFile, err := ioutil.ReadFile(absPath) if err != nil { panic(fmt.Errorf("讀取配置文件失敗: %v", err)) } if err := yaml.Unmarshal(configFile, &config); err != nil { panic(fmt.Errorf("格式化配置文件失敗: %v", err)) } // 初始化检查默认服务器 if err := InitDefaultServer(); err != nil { panic(fmt.Errorf("初始化默认服务器失败: %v", err)) } } // 检查默认服务器是否存在, 不存在则添加 func InitDefaultServer() (err error) { var server Server if err = configs.ORMDB().Where("id = ?", "default").First(&server).Error; err != nil { server = Server{ID: "default", IP: "106.15.192.42", Port: 7860, Type: "推理", Status: "就绪"} if err = configs.ORMDB().Create(&server).Error; err != nil { return fmt.Errorf("创建默认服务器失败: %v", err) } } // 初始化服务器中的模型列表 if err = server.InitModels(); err != nil { return fmt.Errorf("初始化服务器中的模型列表失败: %v", err) } return } // 初始化服务器中的模型列表 func (server *Server) InitModels() (err error) { // 刷新检查点 /sdapi/v1/refresh-checkpoint resp_update, err := http.Post(fmt.Sprintf("http://%s:%d/sdapi/v1/refresh-checkpoints", server.IP, server.Port), "application/json", nil) if err != nil { return fmt.Errorf("更新服务器中的模型列表失败: %v", err) } defer resp_update.Body.Close() log.Println("更新服务器中的ckpt模型列表:", resp_update.Status) // 刷新检查点 /sdapi/v1/refresh-loras lora_update, err := http.Post(fmt.Sprintf("http://%s:%d/sdapi/v1/refresh-loras", server.IP, server.Port), "application/json", nil) if err != nil { return fmt.Errorf("更新服务器中的模型列表失败: %v", err) } defer lora_update.Body.Close() log.Println("更新服务器中的lora模型列表:", lora_update.Status) // 获取服务器中的模型列表(Lora) resp_lora, err := http.Get(fmt.Sprintf("http://%s:%d/sdapi/v1/loras", server.IP, server.Port)) if err != nil { return fmt.Errorf("获取服务器中的模型列表失败: %v", err) } defer resp_lora.Body.Close() // 解码JSON (数组) var data_lora []map[string]interface{} if err := json.NewDecoder(resp_lora.Body).Decode(&data_lora); err != nil { return err } //// 加载所有lora模型列表取得id和hash和path //var lora_models []map[string]interface{} //if err := configs.ORMDB().Table("models").Where("type = ?", "lora").Select("id", "name", "hash", "model_path").Find(&lora_models).Error; err != nil { // return fmt.Errorf("获取模型列表失败: %v", err) //} //// 加载所有基础模型列表取得id和hash和path //var base_models []map[string]interface{} //if err := configs.ORMDB().Table("models").Where("type = ?", "ckp").Select("id", "name", "hash", "model_path").Find(&base_models).Error; err != nil { // return fmt.Errorf("获取模型列表失败: %v", err) //} //判断模型是否存在 := func(list []map[string]interface{}, item map[string]interface{}) bool { // for _, item2 := range list { // fmt.Println(item2) // return true // //if item["hash"].(string) == item2["hash"].(string) { // // return true // //} // } // return false //} //// 遍历数据库中的模型列表, 如果不存在则添加 //for _, item := range lora_models { // fmt.Println("模型:", item["id"], item["hash"], item["name"]) // if !判断模型是否存在(data_lora, item) { // // 从数据库删除不存在的模型 // //if err := configs.ORMDB().Delete(&Model{}, item["id"]).Error; err != nil { // // return fmt.Errorf("删除模型失败: %v", err) // //} // fmt.Println("模型:", item["id"], item["hash"], item["name"], "\033[31mfail\033[0m") // } // fmt.Println("模型:", item["id"], item["hash"], item["name"], "\033[32mok\033[0m") //} //for _, item := range base_models { // fmt.Println("模型:", item["id"], item["hash"], item["name"]) //} //for _, item := range data_lora { // for _, model := range base_models { // if item["hash"].(string) == model["hash"].(string) { // } // } //} // 从数据库检查此模型hash是否存在 for _, item := range data_lora { var model Model if err := configs.ORMDB().Where("hash = ?", item["path"].(string)).First(&model).Error; err != nil { // 不存在则添加 model = Model{ Name: item["name"].(string), Hash: item["path"].(string), ModelCheckpoint: item["alias"].(string), ModelPath: item["path"].(string), ServerID: server.ID, Type: "lora", } // 添加到数据库 if err := configs.ORMDB().Create(&model).Error; err != nil { return fmt.Errorf("添加模型到数据库失败: %v", err) } // 添加到模型列表 server.Models = append(server.Models, model.ID) } } // 更新数据库 if err := configs.ORMDB().Save(&server).Error; err != nil { return fmt.Errorf("更新数据库失败: %v", err) } // 刷新检查点 if err = goreq.Post(fmt.Sprintf("http://%s:%d/sdapi/v1/refresh-checkpoints", server.IP, server.Port)).Do().Err; err != nil { return fmt.Errorf("刷新检查点失败: %v", err) } // 获取服务器中的模型列表(ckp) dexs := []struct { Name string `json:"model_name"` Hash string `json:"sha256"` ModelCheckpoint string `json:"title"` ModelPath string `json:"filename"` }{} // 获取服务器中的模型列表(ckp) if err = goreq.Get(fmt.Sprintf("http://%s:%d/sdapi/v1/sd-models", server.IP, server.Port)).Do().BindJSON(&dexs); err != nil { return fmt.Errorf("获取服务器中的模型列表失败: %v", err) } for _, item := range dexs { // 如果hash为空, 则逐一加载这些模型使其生成hash if item.Hash == "" { fmt.Println("加载模型:", item.ModelCheckpoint) if err := goreq.Post(fmt.Sprintf("http://%s:%d/sdapi/v1/options", server.IP, server.Port)).SetJsonBody(map[string]interface{}{ "sd_model_checkpoint": item.ModelCheckpoint, "CLIP_stop_at_last_layers": 2, }).Do().Error(); err != nil { return fmt.Errorf("加载模型失败: %v", err) } } } // 重新获取服务器中的模型列表(ckp), 忽略hash为空的模型 if err = goreq.Get(fmt.Sprintf("http://%s:%d/sdapi/v1/sd-models", server.IP, server.Port)).Do().BindJSON(&dexs); err != nil { return fmt.Errorf("获取服务器中的模型列表失败: %v", err) } for _, item := range dexs { if item.Hash == "" { fmt.Println("忽略模型:", item.ModelCheckpoint) continue } var model Model if err := configs.ORMDB().Where("hash = ?", item.Hash).First(&model).Error; err != nil { // 不存在则添加 model = Model{ Name: item.Name, Hash: item.Hash, ModelCheckpoint: item.ModelCheckpoint, ModelPath: item.ModelPath, ServerID: server.ID, Type: "ckp", } // 添加到数据库 if err := configs.ORMDB().Create(&model).Error; err != nil { return fmt.Errorf("添加模型到数据库失败: %v", err) } // 添加到模型列表 server.Models = append(server.Models, model.ID) } } // 更新数据库 if err := configs.ORMDB().Save(&server).Error; err != nil { return fmt.Errorf("更新数据库失败: %v", err) } // 检查服务器当前加载的模型 if server.ModelID == 0 { var form = struct { SdCheckpointHash string `json:"sd_checkpoint_hash"` SdModelCheckpoint string `json:"sd_model_checkpoint"` }{} // 检查当前是否为目标模型, 不是则执行切换模型 if err := goreq.Get(fmt.Sprintf("http://%s:%d/sdapi/v1/options", server.IP, server.Port)).Do().BindJSON(&form); err != nil { return fmt.Errorf("获取推理机配置失败: %v", err) } fmt.Println("当前模型:", form.SdModelCheckpoint) var model Model if err := configs.ORMDB().Where("model_checkpoint = ?", form.SdModelCheckpoint).First(&model).Error; err != nil { return fmt.Errorf("获取模型信息失败: %v", err) } server.ModelID = model.ID if err := configs.ORMDB().Save(&server).Error; err != nil { return fmt.Errorf("更新数据库失败: %v", err) } } return } // 创建一台新服务器 func NewServer(server_type string) (server Server, err error) { // 调用 API 创建一台新服务器(通過腾讯云API創建服務器) client, err := cvm.NewClient(common.NewCredential(config.TencentCloud.SecretId, config.TencentCloud.SecretKey), config.TencentCloud.Region, profile.NewClientProfile()) if err != nil { return server, fmt.Errorf("初始化騰訊雲SDK客戶端失敗: %v", err) } // 实例化一个请求对象, 指定啓動模板, 以創建指定規格的服務器 request := cvm.NewRunInstancesRequest() request.LaunchTemplate = &cvm.LaunchTemplate{LaunchTemplateId: common.StringPtr("lt-ks6y5evh")} response, err := client.RunInstances(request) if _, ok := err.(*errors.TencentCloudSDKError); ok { return server, fmt.Errorf("已返回 API 错误: %v", err) } if err != nil { return server, fmt.Errorf("运行实例失败: %v", err) } fmt.Println("創建服務器成功:", response.Response.InstanceIdSet[0]) // 获取服务器信息 var get_server_info = func(InstanceIdSet *string) (server Server, err error) { response2, err := client.DescribeInstances(cvm.NewDescribeInstancesRequest()) if err != nil { return server, fmt.Errorf("獲取實例詳情失敗: %v", err) } for _, instance := range response2.Response.InstanceSet { if *instance.InstanceId != *InstanceIdSet { server.ID = *instance.InstanceId server.Name = *instance.InstanceName server.IP = *instance.PublicIpAddresses[0] server.Port = 7890 server.Status = *instance.InstanceState configs.ORMDB().Create(&server) return server, nil } } return server, fmt.Errorf("未取得實例詳情: %v", err) } // 等待服务器创建完成 return get_server_info(response.Response.InstanceIdSet[0]) } // 注销服务器 func (server *Server) Delete() error { client, err := cvm.NewClient(common.NewCredential(config.TencentCloud.SecretId, config.TencentCloud.SecretKey), config.TencentCloud.Region, profile.NewClientProfile()) if err != nil { return fmt.Errorf("初始化騰訊雲SDK客戶端失敗: %v", err) } request := cvm.NewTerminateInstancesRequest() request.InstanceIds = []*string{common.StringPtr(server.ID)} response, err := client.TerminateInstances(request) if _, ok := err.(*errors.TencentCloudSDKError); ok { return fmt.Errorf("已返回 API 错误: %v", err) } if err != nil { return fmt.Errorf("註銷實例失敗: %v", err) } // 從列表中刪除服務器 configs.ORMDB().Delete(&server) fmt.Println("註銷服務器成功:", server.ID, response.Response) return nil } // 檢查服務器是否正常 func (server *Server) CheckStatus() error { switch server.Type { case "训练": resp, err := http.Get(fmt.Sprintf("http://%s:%d/dreambooth/status", server.IP, server.Port)) if err != nil { server.Status = "異常" return err } defer resp.Body.Close() // 解碼JSON var data map[string]interface{} if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { return err } // 解碼JSON var current_state map[string]interface{} if err := json.Unmarshal([]byte(data["current_state"].(string)), ¤t_state); err != nil { return err } //log.Println("current_state:", current_state) // 檢查服務器是否正常 if !current_state["active"].(bool) { server.Status = "異常" return fmt.Errorf("服務器狀態異常: active=false") } server.Status = "正常" case "推理": server.Status = "就绪" default: server.Status = "異常" } // 檢查服務器是否正常 return nil }