package models import ( "bytes" "crypto/md5" "fmt" "io/ioutil" "log" "main/configs" "net/http" "net/url" "os" "path/filepath" "time" "encoding/base64" "image/png" "github.com/chai2010/webp" "github.com/zhshch2002/goreq" ) type Model struct { ID int `json:"id" gorm:"primary_key"` // 模型ID Name string `json:"name"` // 模型名稱 Info string `json:"info"` // 模型描述 Type string `json:"type"` // 模型類型(lora|ckp|hyper|ti) TriggerWords string `json:"trigger_words"` // 觸發詞 BaseModel string `json:"base_model"` // 基礎模型(SD1.5|SD2) ModelPath string `json:"model_path"` // 模型路徑(實際存放在服務器上的路徑) Status string `json:"status" default:"initial"` // (initial|ready|waiting|running|success|error) Progress int `json:"progress"` // (0-100) Image string `json:"image"` // 封面圖片實際地址 Hash string `json:"hash"` // 模型哈希值 Epochs int `json:"epochs"` // 訓練步數 LearningRate float32 `json:"learning_rate"` // 學習率(0.000005) Tags TagList `json:"tags"` // 模型標籤(標籤名數組) UserID int `json:"user_id"` // 模型的所有者 DatasetID int `json:"dataset_id"` // 模型所使用的數據集ID ServerID string `json:"server_id"` // 模型所在服務器(訓練機或推理機) Stars StarList `json:"stars"` // 模型的收藏者 CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` } // 创建一个带缓冲的通道,缓冲区大小为 10 // var ch = make(chan int, 10) func init() { configs.ORMDB().AutoMigrate(&Model{}) // 检查 images 目录是否存在, 不存在则创建 if _, err := os.Stat("data/images"); err != nil { if err := os.MkdirAll("data/images", 0777); err != nil { log.Println(err) } } // 处理推理任务 //go func() { // for { // // 从通道中取出一个数据 // model := <-ch // // 模型状态变化时, 向监听此模型的所有连接发送消息 // } //}() } func (model *Model) Inference(image_list []Image, callback func(Image)) { // 模型未部署到推理機 if model.ServerID == "" { //log.Println("模型未部署到推理機, 开始部署模型") log.Println("模型已部署到推理機, 开始推理模型") var server Server server.IP = "106.15.192.42" server.Port = 7860 //if err := configs.ORMDB().Where("models LIKE ?", "%"+model.Name+"%").Take(&server).Error; err != nil { // log.Println(err) // // 如果没有则寻找空闲服务器 // // 如果没有空闲则创建新服务器 // // 取一台空闲的推理机上传并切换到此模型 // // 新建一台推理机上传并切换到此模型 //} // 执行生成任务 if model.Image == "" { // 发送的参数 var data = struct { //EnableHr bool `json:"enable_hr"` //DenoisingStrength int `json:"denoising_strength"` //FirstphaseWidth int `json:"firstphase_width"` //FirstphaseHeight int `json:"firstphase_height"` //HrScale int `json:"hr_scale"` //HrUpscaler string `json:"hr_upscaler"` //HrSecondPassSteps int `json:"hr_second_pass_steps"` //HrResizeX int `json:"hr_resize_x"` //HrResizeY int `json:"hr_resize_y"` //HrSamplerName string `json:"hr_sampler_name"` //HrPrompt string `json:"hr_prompt"` //HrNegativePrompt string `json:"hr_negative_prompt"` Prompt string `json:"prompt"` //Styles []string `json:"styles"` //Seed int `json:"seed"` //Subseed int `json:"subseed"` //SubseedStrength int `json:"subseed_strength"` //SeedResizeFromH int `json:"seed_resize_from_h"` //SeedResizeFromW int `json:"seed_resize_from_w"` //SamplerName string `json:"sampler_name"` //BatchSize int `json:"batch_size"` NIter int `json:"n_iter"` Steps int `json:"steps"` CfgScale int `json:"cfg_scale"` //Width int `json:"width"` //Height int `json:"height"` //RestoreFaces bool `json:"restore_faces"` //Tiling bool `json:"tiling"` //DoNotSaveSamples bool `json:"do_not_save_samples"` //DoNotSaveGrid bool `json:"do_not_save_grid"` //NegativePrompt string `json:"negative_prompt"` //Eta int `json:"eta"` //SMinUncond int `json:"s_min_uncond"` //SChurn int `json:"s_churn"` //STmax int `json:"s_tmax"` //STmin int `json:"s_tmin"` //SNoise int `json:"s_noise"` //OverrideSettings map[string]string `json:"override_settings"` //OverrideSettingsRestoreAfterwards bool `json:"override_settings_restore_afterwards"` //ScriptArgs []interface{} `json:"script_args"` //SamplerIndex string `json:"sampler_index"` //ScriptName string `json:"script_name"` //SendImages bool `json:"send_images"` //SaveImages bool `json:"save_images"` //AlwaysonScripts map[string]string `json:"alwayson_scripts"` }{ //EnableHr: false, //DenoisingStrength: 0, //FirstphaseWidth: 0, //FirstphaseHeight: 0, //HrScale: 2, //HrUpscaler: "nearest", //HrSecondPassSteps: 0, //HrResizeX: 0, //HrResizeY: 0, //HrSamplerName: "", //HrPrompt: "", //HrNegativePrompt: "", Prompt: image_list[0].Prompt, //Styles: []string{}, //Seed: -1, //Subseed: -1, //SubseedStrength: 0, //SeedResizeFromH: -1, //SeedResizeFromW: -1, //SamplerName: "beamsearch", //BatchSize: 1, NIter: len(image_list), // 1~100 Steps: 50, // 1~150 CfgScale: image_list[0].CfgScale, //Width: 512, //Height: 512, //RestoreFaces: false, //Tiling: false, //DoNotSaveSamples: false, //DoNotSaveGrid: false, //NegativePrompt: "", //Eta: 0, //SMinUncond: 0, //SChurn: 0, //STmax: 0, //STmin: 0, //SNoise: 1, //OverrideSettings: map[string]string{}, //OverrideSettingsRestoreAfterwards: false, //ScriptArgs: []interface{}{}, //SamplerIndex: "Euler", //ScriptName: "generate", //SendImages: true, //SaveImages: false, //AlwaysonScripts: map[string]string{}, } // 接收到的图片列表 var rest = struct { Images []string `json:"images"` }{} var url = fmt.Sprintf("http://%s:%d/sdapi/v1/txt2img", server.IP, server.Port) if err := goreq.Post(url).SetJsonBody(data).Do().BindJSON(&rest); err != nil { log.Println("API 查询失败:", err) } for index, img := range rest.Images { var filename = fmt.Sprintf("%x", md5.Sum([]byte(img))) log.Println("保存图片:", filename) if err := SaveBase64Image(img, "data/images/"+filename+".webp"); err != nil { log.Println(err) } image_list[index].Name = filename image_list[index].Path = "data/images/" + filename + ".webp" image_list[index].Hash = filename image_list[index].Type = "image/webp" image_list[index].Width = 512 image_list[index].Height = 512 image_list[index].Format = "webp" image_list[index].Status = "success" image_list[index].Progress = 100 callback(image_list[index]) } } return } log.Println("模型未部署到推理機, 取消推理模型") } // 将base64编码的图片保存到本地webp func SaveBase64Image(base64Str string, filename string) error { // 解码base64图片 data, err := base64.StdEncoding.DecodeString(base64Str) if err != nil { return err } // 将png图片解码为image.Image img, err := png.Decode(bytes.NewReader(data)) if err != nil { return err } // 创建webp编码器 webpWriter, err := os.Create(filename) if err != nil { return err } defer webpWriter.Close() // 将image.Image编码为webp格式并保存到本地 if err := webp.Encode(webpWriter, img, &webp.Options{Lossless: true}); err != nil { return err } return nil } func (model *Model) Train() (err error) { // 獲取一臺空閒的訓練機 var server Server if err = configs.ORMDB().Where("status = ?", "正常").First(&server).Error; err != nil { fmt.Println(err) // TOOD: 沒有空閒的訓練機, 訓練排隊, 等待訓練機空閒 // TODO: 如果訓練機數量低於10臺, 則創建新的訓練機 return } // 獲取數據集 var dataset Dataset = Dataset{ID: model.DatasetID} if err = configs.ORMDB().First(&dataset).Error; err != nil { fmt.Println(err) return } // 更新模型狀態 model.ServerID = server.ID model.Status = "training" if err = configs.ORMDB().Save(&model).Error; err != nil { fmt.Println(err) return } // 創建數據集目錄 dirPath := filepath.Join("data/datasets", fmt.Sprint(dataset.ID), "images") if err := os.MkdirAll(dirPath, 0755); err != nil { fmt.Println(err) return err } // 將數據下載到本地 for index, url := range dataset.Images { fmt.Println("下載數據到本地:", index, url) // 檢查文件是否已經存在 filename := fmt.Sprintf("%x", md5.Sum([]byte(url))) filePath := filepath.Join(dirPath, filename) if _, err := os.Stat(filePath); err == nil { fmt.Println("文件已經存在:", filePath) continue } // 下載到臨時目錄 resp, err := http.Get(url) if err != nil { fmt.Println("下載失敗:", err) continue } defer resp.Body.Close() data, err := ioutil.ReadAll(resp.Body) if err != nil { fmt.Println("保存失敗:", err) continue } // 保存文件到本地目錄下 if err := ioutil.WriteFile(filePath, data, 0644); err != nil { fmt.Println(err) continue } } fmt.Println("數據下載完成") // 檢查目錄下是否有文件, 如果沒有文件則返回錯誤 files, err := ioutil.ReadDir(dirPath) if err != nil || len(files) == 0 { fmt.Println("目錄下沒有文件") return fmt.Errorf("目錄下沒有文件") } // 將數據上傳到訓練機 // 按類型執行訓練任務 if model.Type == "dreambooth" { // 創建數據庫模型 fmt.Println("創建數據庫模型 ======================================") resp, err := http.PostForm(fmt.Sprintf("http://%s:%d/dreambooth/createModel?new_model_name=%s&new_model_src=%s", server.IP, server.Port, model.Name, model.ModelPath), nil) if err != nil { fmt.Println("創建訓練任務失敗:", err.Error()) return err } defer resp.Body.Close() // 打印返回的結果 body, err := ioutil.ReadAll(resp.Body) if err != nil { fmt.Println("解碼任務數據失敗:", err) return err } fmt.Println("預覽:", string(body)) // 上傳數據到訓練機 // 執行訓練命令 } if model.Type == "lora" { // 創建數據庫模型 formData := url.Values{} formData.Set("name", model.Name) formData.Set("type", model.Type) resp, err := http.PostForm(fmt.Sprintf("http://%s:%d/lora/createModel", server.IP, server.Port), formData) if err != nil { fmt.Println(err) return err } defer resp.Body.Close() // 上傳數據到訓練機 // 執行訓練命令 } //// 將文件全部上傳到訓練機, 使用scp命令,自動使用密碼登錄 //err = exec.Command("scp", "-r", dirPath, fmt.Sprintf("%s@%s:~/dataset_%d", server.UserName, server.IP, model.ID)).Run() //if err != nil { // fmt.Println(err) // return err //} //// 刪除本地臨時目錄 //if err := os.RemoveAll(dirPath); err != nil { // fmt.Println(err) // return err //} //// 将基础模型上传到训练机(使用scp命令) //baseModelPath := filepath.Join("data/models", model.BaseModel) //fmt.Println("baseModelPath:", baseModelPath) //err = exec.Command("scp", baseModelPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run() //if err != nil { // fmt.Println(err) // return err //} //// 進行訓練(訓練機上調用訓練webapi接口:參數) //resp, err := http.Post(fmt.Sprintf("http://%s:5000/train", server.IP), "application/json", nil) //if err != nil { // fmt.Println(err) // return err //} //defer resp.Body.Close() //// 循環監聽訓練進度 //for i := 0; i < 5; i++ { // // 訓練機上調用訓練webapi接口:獲取訓練進度 // resp, err := http.Get(fmt.Sprintf("http://%s:5000/progress", server.IP)) // if err != nil { // fmt.Println(err) // return err // } // defer resp.Body.Close() //// 更新本地訓練進度 // var progress int // if err := json.NewDecoder(resp.Body).Decode(&progress); err != nil { // fmt.Println(err) // return err // } //} // // TODO: 訓練完成後將模型下載到本地 return nil }