模型訓練

This commit is contained in:
2023-05-28 00:44:13 +08:00
parent 0966a3c83e
commit ed7e09e736
5 changed files with 92 additions and 54 deletions

View File

@@ -2,13 +2,12 @@ package models
import ( import (
"crypto/md5" "crypto/md5"
"encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"main/configs" "main/configs"
"net/http" "net/http"
"net/url"
"os" "os"
"os/exec"
"path/filepath" "path/filepath"
"time" "time"
) )
@@ -115,53 +114,91 @@ func (model *Model) Train() (err error) {
return fmt.Errorf("目錄下沒有文件") return fmt.Errorf("目錄下沒有文件")
} }
// 將文件全部上傳到訓練機(使用scp命令) // 按類型執行訓練任務
err = exec.Command("scp", "-r", dirPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run() if model.Type == "dreambooth" {
// 創建數據庫模型
fmt.Println("創建數據庫模型 ======================================")
resp, err := http.PostForm(fmt.Sprintf("http://%s:%d/dreambooth/createModel?new_model_name=%s&new_model_src=%s", server.IP, server.Port, model.Name, model.ModelPath), nil)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println("創建訓練任務失敗:", err.Error())
return err return err
} }
defer resp.Body.Close()
// 刪除本地臨時目錄 // 打印返回的結果
if err := os.RemoveAll(dirPath); err != nil { body, err := ioutil.ReadAll(resp.Body)
fmt.Println(err)
return err
}
// 将基础模型上传到训练机(使用scp命令)
baseModelPath := filepath.Join("data/models", model.BaseModel)
fmt.Println("baseModelPath:", baseModelPath)
err = exec.Command("scp", baseModelPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println("解碼任務數據失敗:", err)
return err return err
} }
fmt.Println("預覽:", string(body))
// 進行訓練(訓練機上調用訓練webapi接口:參數) // 上傳數據到訓練機
resp, err := http.Post(fmt.Sprintf("http://%s:5000/train", server.IP), "application/json", nil)
// 執行訓練命令
}
if model.Type == "lora" {
// 創建數據庫模型
formData := url.Values{}
formData.Set("name", model.Name)
formData.Set("type", model.Type)
resp, err := http.PostForm(fmt.Sprintf("http://%s:%d/lora/createModel", server.IP, server.Port), formData)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return err return err
} }
defer resp.Body.Close() defer resp.Body.Close()
// 循環監聽訓練進度 // 上傳數據到訓練機
for {
// 訓練機上調用訓練webapi接口:獲取訓練進度
resp, err := http.Get(fmt.Sprintf("http://%s:5000/progress", server.IP))
if err != nil {
fmt.Println(err)
return err
}
defer resp.Body.Close()
// 更新本地訓練進度 // 執行訓練命令
var progress int
if err := json.NewDecoder(resp.Body).Decode(&progress); err != nil {
fmt.Println(err)
return err
}
} }
//// 將文件全部上傳到訓練機, 使用scp命令,自動使用密碼登錄
//err = exec.Command("scp", "-r", dirPath, fmt.Sprintf("%s@%s:~/dataset_%d", server.UserName, server.IP, model.ID)).Run()
//if err != nil {
// fmt.Println(err)
// return err
//}
//// 刪除本地臨時目錄
//if err := os.RemoveAll(dirPath); err != nil {
// fmt.Println(err)
// return err
//}
//// 将基础模型上传到训练机(使用scp命令)
//baseModelPath := filepath.Join("data/models", model.BaseModel)
//fmt.Println("baseModelPath:", baseModelPath)
//err = exec.Command("scp", baseModelPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run()
//if err != nil {
// fmt.Println(err)
// return err
//}
//// 進行訓練(訓練機上調用訓練webapi接口:參數)
//resp, err := http.Post(fmt.Sprintf("http://%s:5000/train", server.IP), "application/json", nil)
//if err != nil {
// fmt.Println(err)
// return err
//}
//defer resp.Body.Close()
//// 循環監聽訓練進度
//for i := 0; i < 5; i++ {
// // 訓練機上調用訓練webapi接口:獲取訓練進度
// resp, err := http.Get(fmt.Sprintf("http://%s:5000/progress", server.IP))
// if err != nil {
// fmt.Println(err)
// return err
// }
// defer resp.Body.Close()
//// 更新本地訓練進度
// var progress int
// if err := json.NewDecoder(resp.Body).Decode(&progress); err != nil {
// fmt.Println(err)
// return err
// }
//}
//
// TODO: 訓練完成後將模型下載到本地 // TODO: 訓練完成後將模型下載到本地
return nil
} }

View File

@@ -12,7 +12,7 @@ type Server struct {
IP string `json:"ip"` IP string `json:"ip"`
Port int `json:"port"` Port int `json:"port"`
Status string `json:"status"` // (異常|初始化|閒置|就緒|工作中|關閉中) Status string `json:"status"` // (異常|初始化|閒置|就緒|工作中|關閉中)
Username string `json:"username"` UserName string `json:"username"`
Password string `json:"password"` Password string `json:"password"`
Models []map[string]interface{} `json:"models" gorm:"-"` // 數據庫不必保存 Models []map[string]interface{} `json:"models" gorm:"-"` // 數據庫不必保存
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`

View File

@@ -12,7 +12,7 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
) )
// 獲取用戶列表 // 用戶列表
func UsersGet(w http.ResponseWriter, r *http.Request) { func UsersGet(w http.ResponseWriter, r *http.Request) {
var listview models.ListView var listview models.ListView
listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1) listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1)

24
test.sh
View File

@@ -53,19 +53,19 @@ response=$(curl -X PATCH -H "Content-Type: application/json" -d '{"images":["htt
[[ ${response: -3} -eq 200 ]] && { echo "修改數據集成功: ${response%???}"; } || exit_service "修改數據集失敗: ${response%???}" [[ ${response: -3} -eq 200 ]] && { echo "修改數據集成功: ${response%???}"; } || exit_service "修改數據集失敗: ${response%???}"
# 訓練模型 (POST /api/models) ## 訓練模型 (POST /api/models)
response=$(curl -X POST -H "Content-Type: application/json" -d '{"name":"test","type":"lora","trigger_words":"miao~","base_model":"sd1.5","epochs":20,"description":"test","dataset_id":'$dataset_id'}' -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models) #response=$(curl -X POST -H "Content-Type: application/json" -d '{"name":"test","type":"dreambooth","trigger_words":"miao~","base_model":"sd1.5","epochs":20,"description":"test","dataset_id":'$dataset_id'}' -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models)
[[ ${response: -3} -eq 200 ]] && { echo "訓練模型任務已創建: ${response%???}"; } || exit_service "訓練模型任務創建失敗: ${response%???}" #[[ ${response: -3} -eq 200 ]] && { echo "訓練模型任務已創建: ${response%???}"; } || exit_service "訓練模型任務創建失敗: ${response%???}"
#
#
## 取模型id的值, 值爲 int
#model_id=$(echo "${response%???}" | grep -o '"id": [0-9]*' | awk '{print $2}')
#echo "model_id: $model_id"
#
# 模型列表 (GET /api/models)
# 取模型id的值, 值爲 int response=$(curl -X GET -H "Content-Type: application/json" -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models)
model_id=$(echo "${response%???}" | grep -o '"id": [0-9]*' | awk '{print $2}') [[ ${response: -3} -eq 200 ]] && { echo "獲取模型列表成功: ${response%???}"; } || exit_service "獲取模型列表失敗: ${response%???}"
echo "model_id: $model_id"
## 模型列表 (GET /api/models)
# response=$(curl -X GET -H "Content-Type: application/json" -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models)
# [[ ${response: -3} -eq 200 ]] && { echo "獲取模型列表成功: ${response%???}"; } || exit_service "獲取模型列表失敗: ${response%???}"
## 獲取模型訓練進度 (GET /api/models/:id) ## 獲取模型訓練進度 (GET /api/models/:id)

View File

@@ -6,6 +6,7 @@ go build -o data/gameui-ai-server main.go
# 上传文件 # 上传文件
scp ./data/gameui-ai-server root@47.103.40.152:~/gameui-ai-server_new scp ./data/gameui-ai-server root@47.103.40.152:~/gameui-ai-server_new
rm -rf ./data/gameui-ai-server
# 重啓服務 # 重啓服務
ssh root@47.103.40.152 ''' ssh root@47.103.40.152 '''