package models import ( "crypto/md5" "fmt" "io/ioutil" "main/configs" "net/http" "net/url" "os" "path/filepath" "time" ) type Model struct { ID int `json:"id" gorm:"primary_key"` Name string `json:"name"` Type string `json:"type"` // (lora|ckp|hyper|ti) TriggerWords string `json:"trigger_words"` // 觸發詞 BaseModel string `json:"base_model"` // (SD1.5|SD2) ModelPath string `json:"model_path"` // 模型路徑 Status string `json:"status" default:"initial"` // (initial|ready|waiting|running|success|error) Progress int `json:"progress"` // (0-100) Image string `json:"image"` // 封面圖片實際地址 Hash string `json:"hash"` // 模型哈希值 Epochs int `json:"epochs"` // 訓練步數 Tags TagList `json:"tags"` UserID int `json:"user_id"` DatasetID int `json:"dataset_id"` ServerID int `json:"server_id"` CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` } func init() { configs.ORMDB().AutoMigrate(&Model{}) } func (model *Model) Train() (err error) { // 獲取一臺空閒的訓練機 var server Server if err = configs.ORMDB().Where("status = ?", "閒置").First(&server).Error; err != nil { fmt.Println(err) // TOOD: 沒有空閒的訓練機, 訓練排隊, 等待訓練機空閒 // TODO: 如果訓練機數量低於10臺, 則創建新的訓練機 return } // 獲取數據集 var dataset Dataset = Dataset{ID: model.DatasetID} if err = configs.ORMDB().First(&dataset).Error; err != nil { fmt.Println(err) return } // 更新模型狀態 model.ServerID = server.ID model.Status = "training" if err = configs.ORMDB().Save(&model).Error; err != nil { fmt.Println(err) return } // 創建數據集目錄 dirPath := filepath.Join("data/datasets", fmt.Sprint(dataset.ID), "images") if err := os.MkdirAll(dirPath, 0755); err != nil { fmt.Println(err) return err } // 將數據下載到本地 for index, url := range dataset.Images { fmt.Println("下載數據到本地:", index, url) // 下載到臨時目錄 resp, err := http.Get(url) if err != nil { fmt.Println("下載失敗:", err) continue } defer resp.Body.Close() data, err := ioutil.ReadAll(resp.Body) if err != nil { fmt.Println("保存失敗:", err) continue } // 保存文件到本地目錄下(自動創建目錄,文件名為url的md5值) filename := fmt.Sprintf("%x", md5.Sum([]byte(url))) filePath := filepath.Join(dirPath, filename) if err := os.MkdirAll(dirPath, 0755); err != nil { fmt.Println(err) continue } if err := ioutil.WriteFile(filePath, data, 0644); err != nil { fmt.Println(err) continue } } fmt.Println("數據下載完成") // 檢查目錄下是否有文件, 如果沒有文件則返回錯誤 files, err := ioutil.ReadDir(dirPath) if err != nil { fmt.Println(err) return err } if len(files) == 0 { fmt.Println("目錄下沒有文件") return fmt.Errorf("目錄下沒有文件") } // 按類型執行訓練任務 if model.Type == "dreambooth" { // 創建數據庫模型 fmt.Println("創建數據庫模型 ======================================") resp, err := http.PostForm(fmt.Sprintf("http://%s:%d/dreambooth/createModel?new_model_name=%s&new_model_src=%s", server.IP, server.Port, model.Name, model.ModelPath), nil) if err != nil { fmt.Println("創建訓練任務失敗:", err.Error()) return err } defer resp.Body.Close() // 打印返回的結果 body, err := ioutil.ReadAll(resp.Body) if err != nil { fmt.Println("解碼任務數據失敗:", err) return err } fmt.Println("預覽:", string(body)) // 上傳數據到訓練機 // 執行訓練命令 } if model.Type == "lora" { // 創建數據庫模型 formData := url.Values{} formData.Set("name", model.Name) formData.Set("type", model.Type) resp, err := http.PostForm(fmt.Sprintf("http://%s:%d/lora/createModel", server.IP, server.Port), formData) if err != nil { fmt.Println(err) return err } defer resp.Body.Close() // 上傳數據到訓練機 // 執行訓練命令 } //// 將文件全部上傳到訓練機, 使用scp命令,自動使用密碼登錄 //err = exec.Command("scp", "-r", dirPath, fmt.Sprintf("%s@%s:~/dataset_%d", server.UserName, server.IP, model.ID)).Run() //if err != nil { // fmt.Println(err) // return err //} //// 刪除本地臨時目錄 //if err := os.RemoveAll(dirPath); err != nil { // fmt.Println(err) // return err //} //// 将基础模型上传到训练机(使用scp命令) //baseModelPath := filepath.Join("data/models", model.BaseModel) //fmt.Println("baseModelPath:", baseModelPath) //err = exec.Command("scp", baseModelPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run() //if err != nil { // fmt.Println(err) // return err //} //// 進行訓練(訓練機上調用訓練webapi接口:參數) //resp, err := http.Post(fmt.Sprintf("http://%s:5000/train", server.IP), "application/json", nil) //if err != nil { // fmt.Println(err) // return err //} //defer resp.Body.Close() //// 循環監聽訓練進度 //for i := 0; i < 5; i++ { // // 訓練機上調用訓練webapi接口:獲取訓練進度 // resp, err := http.Get(fmt.Sprintf("http://%s:5000/progress", server.IP)) // if err != nil { // fmt.Println(err) // return err // } // defer resp.Body.Close() //// 更新本地訓練進度 // var progress int // if err := json.NewDecoder(resp.Body).Decode(&progress); err != nil { // fmt.Println(err) // return err // } //} // // TODO: 訓練完成後將模型下載到本地 return nil }