package models import ( "crypto/md5" "encoding/json" "fmt" "io/ioutil" "main/configs" "net/http" "os" "os/exec" "path/filepath" "time" ) type Model struct { ID int `json:"id" gorm:"primary_key"` Name string `json:"name"` Type string `json:"type"` // (lora|ckp|hyper|ti) TriggerWords string `json:"trigger_words"` // 觸發詞 BaseModel string `json:"base_model"` // (SD1.5|SD2) ModelPath string `json:"model_path"` // 模型路徑 Status string `json:"status" default:"initial"` // (initial|ready|waiting|running|success|error) Progress int `json:"progress"` // (0-100) Image string `json:"image"` // 封面圖片實際地址 Hash string `json:"hash"` // 模型哈希值 Epochs int `json:"epochs"` // 訓練步數 Tags TagList `json:"tags"` UserID int `json:"user_id"` DatasetID int `json:"dataset_id"` ServerID int `json:"server_id"` CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` } func init() { configs.ORMDB().AutoMigrate(&Model{}) } func (model *Model) Train() (err error) { // 獲取一臺空閒的訓練機 var server Server if err = configs.ORMDB().Where("status = ?", "閒置").First(&server).Error; err != nil { fmt.Println(err) return } // 獲取數據集 var dataset Dataset = Dataset{ID: model.DatasetID} if err = configs.ORMDB().First(&dataset).Error; err != nil { fmt.Println(err) return } // 更新模型狀態 model.ServerID = server.ID model.Status = "training" if err = configs.ORMDB().Save(&model).Error; err != nil { fmt.Println(err) return } // 創建數據集目錄 dirPath := filepath.Join("data/datasets", fmt.Sprint(dataset.ID), "images") if err := os.MkdirAll(dirPath, 0755); err != nil { fmt.Println(err) return err } // 將數據下載到本地 for index, url := range dataset.Images { fmt.Println("下載數據到本地:", index, url) // 下載到臨時目錄 resp, err := http.Get(url) if err != nil { fmt.Println("下載失敗:", err) continue } defer resp.Body.Close() data, err := ioutil.ReadAll(resp.Body) if err != nil { fmt.Println("保存失敗:", err) continue } // 保存文件到本地目錄下(自動創建目錄,文件名為url的md5值) filename := fmt.Sprintf("%x", md5.Sum([]byte(url))) filePath := filepath.Join(dirPath, filename) if err := os.MkdirAll(dirPath, 0755); err != nil { fmt.Println(err) continue } if err := ioutil.WriteFile(filePath, data, 0644); err != nil { fmt.Println(err) continue } } fmt.Println("數據下載完成") // 檢查目錄下是否有文件, 如果沒有文件則返回錯誤 files, err := ioutil.ReadDir(dirPath) if err != nil { fmt.Println(err) return err } if len(files) == 0 { fmt.Println("目錄下沒有文件") return fmt.Errorf("目錄下沒有文件") } // 將文件全部上傳到訓練機(使用scp命令) err = exec.Command("scp", "-r", dirPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run() if err != nil { fmt.Println(err) return err } // 刪除本地臨時目錄 if err := os.RemoveAll(dirPath); err != nil { fmt.Println(err) return err } // 将基础模型上传到训练机(使用scp命令) baseModelPath := filepath.Join("data/models", model.BaseModel) fmt.Println("baseModelPath:", baseModelPath) err = exec.Command("scp", baseModelPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run() if err != nil { fmt.Println(err) return err } // 進行訓練(訓練機上調用訓練webapi接口:參數) resp, err := http.Post(fmt.Sprintf("http://%s:5000/train", server.IP), "application/json", nil) if err != nil { fmt.Println(err) return err } defer resp.Body.Close() // 循環監聽訓練進度 for { // 訓練機上調用訓練webapi接口:獲取訓練進度 resp, err := http.Get(fmt.Sprintf("http://%s:5000/progress", server.IP)) if err != nil { fmt.Println(err) return err } defer resp.Body.Close() // 更新本地訓練進度 var progress int if err := json.NewDecoder(resp.Body).Decode(&progress); err != nil { fmt.Println(err) return err } } // TODO: 訓練完成後將模型下載到本地 }