Files
ai/models/Model.go
2023-06-06 15:51:24 +08:00

207 lines
6.2 KiB
Go

package models
import (
"crypto/md5"
"fmt"
"io/ioutil"
"main/configs"
"net/http"
"net/url"
"os"
"path/filepath"
"time"
)
type Model struct {
ID int `json:"id" gorm:"primary_key"` // 模型ID
Name string `json:"name"` // 模型名稱
Info string `json:"info"` // 模型描述
Type string `json:"type"` // 模型類型(lora|ckp|hyper|ti)
TriggerWords string `json:"trigger_words"` // 觸發詞
BaseModel string `json:"base_model"` // 基礎模型(SD1.5|SD2)
ModelPath string `json:"model_path"` // 模型路徑(實際存放在服務器上的路徑)
Status string `json:"status" default:"initial"` // (initial|ready|waiting|running|success|error)
Progress int `json:"progress"` // (0-100)
Image string `json:"image"` // 封面圖片實際地址
Hash string `json:"hash"` // 模型哈希值
Epochs int `json:"epochs"` // 訓練步數
LearningRate float32 `json:"learning_rate"` // 學習率(0.000005)
Tags TagList `json:"tags"` // 模型標籤(標籤名數組)
UserID int `json:"user_id"` // 模型的所有者
DatasetID int `json:"dataset_id"` // 模型所使用的數據集ID
ServerID string `json:"server_id"` // 模型所在服務器(訓練機或推理機)
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
}
func init() {
configs.ORMDB().AutoMigrate(&Model{})
}
func (model *Model) Train() (err error) {
// 獲取一臺空閒的訓練機
var server Server
if err = configs.ORMDB().Where("status = ?", "正常").First(&server).Error; err != nil {
fmt.Println(err)
// TOOD: 沒有空閒的訓練機, 訓練排隊, 等待訓練機空閒
// TODO: 如果訓練機數量低於10臺, 則創建新的訓練機
return
}
// 獲取數據集
var dataset Dataset = Dataset{ID: model.DatasetID}
if err = configs.ORMDB().First(&dataset).Error; err != nil {
fmt.Println(err)
return
}
// 更新模型狀態
model.ServerID = server.ID
model.Status = "training"
if err = configs.ORMDB().Save(&model).Error; err != nil {
fmt.Println(err)
return
}
// 創建數據集目錄
dirPath := filepath.Join("data/datasets", fmt.Sprint(dataset.ID), "images")
if err := os.MkdirAll(dirPath, 0755); err != nil {
fmt.Println(err)
return err
}
// 將數據下載到本地
for index, url := range dataset.Images {
fmt.Println("下載數據到本地:", index, url)
// 檢查文件是否已經存在
filename := fmt.Sprintf("%x", md5.Sum([]byte(url)))
filePath := filepath.Join(dirPath, filename)
if _, err := os.Stat(filePath); err == nil {
fmt.Println("文件已經存在:", filePath)
continue
}
// 下載到臨時目錄
resp, err := http.Get(url)
if err != nil {
fmt.Println("下載失敗:", err)
continue
}
defer resp.Body.Close()
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
fmt.Println("保存失敗:", err)
continue
}
// 保存文件到本地目錄下
if err := ioutil.WriteFile(filePath, data, 0644); err != nil {
fmt.Println(err)
continue
}
}
fmt.Println("數據下載完成")
// 檢查目錄下是否有文件, 如果沒有文件則返回錯誤
files, err := ioutil.ReadDir(dirPath)
if err != nil || len(files) == 0 {
fmt.Println("目錄下沒有文件")
return fmt.Errorf("目錄下沒有文件")
}
// 將數據上傳到訓練機
// 按類型執行訓練任務
if model.Type == "dreambooth" {
// 創建數據庫模型
fmt.Println("創建數據庫模型 ======================================")
resp, err := http.PostForm(fmt.Sprintf("http://%s:%d/dreambooth/createModel?new_model_name=%s&new_model_src=%s", server.IP, server.Port, model.Name, model.ModelPath), nil)
if err != nil {
fmt.Println("創建訓練任務失敗:", err.Error())
return err
}
defer resp.Body.Close()
// 打印返回的結果
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
fmt.Println("解碼任務數據失敗:", err)
return err
}
fmt.Println("預覽:", string(body))
// 上傳數據到訓練機
// 執行訓練命令
}
if model.Type == "lora" {
// 創建數據庫模型
formData := url.Values{}
formData.Set("name", model.Name)
formData.Set("type", model.Type)
resp, err := http.PostForm(fmt.Sprintf("http://%s:%d/lora/createModel", server.IP, server.Port), formData)
if err != nil {
fmt.Println(err)
return err
}
defer resp.Body.Close()
// 上傳數據到訓練機
// 執行訓練命令
}
//// 將文件全部上傳到訓練機, 使用scp命令,自動使用密碼登錄
//err = exec.Command("scp", "-r", dirPath, fmt.Sprintf("%s@%s:~/dataset_%d", server.UserName, server.IP, model.ID)).Run()
//if err != nil {
// fmt.Println(err)
// return err
//}
//// 刪除本地臨時目錄
//if err := os.RemoveAll(dirPath); err != nil {
// fmt.Println(err)
// return err
//}
//// 将基础模型上传到训练机(使用scp命令)
//baseModelPath := filepath.Join("data/models", model.BaseModel)
//fmt.Println("baseModelPath:", baseModelPath)
//err = exec.Command("scp", baseModelPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run()
//if err != nil {
// fmt.Println(err)
// return err
//}
//// 進行訓練(訓練機上調用訓練webapi接口:參數)
//resp, err := http.Post(fmt.Sprintf("http://%s:5000/train", server.IP), "application/json", nil)
//if err != nil {
// fmt.Println(err)
// return err
//}
//defer resp.Body.Close()
//// 循環監聽訓練進度
//for i := 0; i < 5; i++ {
// // 訓練機上調用訓練webapi接口:獲取訓練進度
// resp, err := http.Get(fmt.Sprintf("http://%s:5000/progress", server.IP))
// if err != nil {
// fmt.Println(err)
// return err
// }
// defer resp.Body.Close()
//// 更新本地訓練進度
// var progress int
// if err := json.NewDecoder(resp.Body).Decode(&progress); err != nil {
// fmt.Println(err)
// return err
// }
//}
//
// TODO: 訓練完成後將模型下載到本地
return nil
}