模型訓練
This commit is contained in:
117
models/Model.go
117
models/Model.go
@@ -2,13 +2,12 @@ package models
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"main/configs"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
@@ -115,53 +114,91 @@ func (model *Model) Train() (err error) {
|
||||
return fmt.Errorf("目錄下沒有文件")
|
||||
}
|
||||
|
||||
// 將文件全部上傳到訓練機(使用scp命令)
|
||||
err = exec.Command("scp", "-r", dirPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return err
|
||||
// 按類型執行訓練任務
|
||||
if model.Type == "dreambooth" {
|
||||
// 創建數據庫模型
|
||||
fmt.Println("創建數據庫模型 ======================================")
|
||||
resp, err := http.PostForm(fmt.Sprintf("http://%s:%d/dreambooth/createModel?new_model_name=%s&new_model_src=%s", server.IP, server.Port, model.Name, model.ModelPath), nil)
|
||||
if err != nil {
|
||||
fmt.Println("創建訓練任務失敗:", err.Error())
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 打印返回的結果
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
fmt.Println("解碼任務數據失敗:", err)
|
||||
return err
|
||||
}
|
||||
fmt.Println("預覽:", string(body))
|
||||
|
||||
// 上傳數據到訓練機
|
||||
|
||||
// 執行訓練命令
|
||||
}
|
||||
|
||||
// 刪除本地臨時目錄
|
||||
if err := os.RemoveAll(dirPath); err != nil {
|
||||
fmt.Println(err)
|
||||
return err
|
||||
}
|
||||
|
||||
// 将基础模型上传到训练机(使用scp命令)
|
||||
baseModelPath := filepath.Join("data/models", model.BaseModel)
|
||||
fmt.Println("baseModelPath:", baseModelPath)
|
||||
err = exec.Command("scp", baseModelPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return err
|
||||
}
|
||||
|
||||
// 進行訓練(訓練機上調用訓練webapi接口:參數)
|
||||
resp, err := http.Post(fmt.Sprintf("http://%s:5000/train", server.IP), "application/json", nil)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 循環監聽訓練進度
|
||||
for {
|
||||
// 訓練機上調用訓練webapi接口:獲取訓練進度
|
||||
resp, err := http.Get(fmt.Sprintf("http://%s:5000/progress", server.IP))
|
||||
if model.Type == "lora" {
|
||||
// 創建數據庫模型
|
||||
formData := url.Values{}
|
||||
formData.Set("name", model.Name)
|
||||
formData.Set("type", model.Type)
|
||||
resp, err := http.PostForm(fmt.Sprintf("http://%s:%d/lora/createModel", server.IP, server.Port), formData)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 更新本地訓練進度
|
||||
var progress int
|
||||
if err := json.NewDecoder(resp.Body).Decode(&progress); err != nil {
|
||||
fmt.Println(err)
|
||||
return err
|
||||
}
|
||||
// 上傳數據到訓練機
|
||||
|
||||
// 執行訓練命令
|
||||
}
|
||||
|
||||
//// 將文件全部上傳到訓練機, 使用scp命令,自動使用密碼登錄
|
||||
//err = exec.Command("scp", "-r", dirPath, fmt.Sprintf("%s@%s:~/dataset_%d", server.UserName, server.IP, model.ID)).Run()
|
||||
//if err != nil {
|
||||
// fmt.Println(err)
|
||||
// return err
|
||||
//}
|
||||
//// 刪除本地臨時目錄
|
||||
//if err := os.RemoveAll(dirPath); err != nil {
|
||||
// fmt.Println(err)
|
||||
// return err
|
||||
//}
|
||||
//// 将基础模型上传到训练机(使用scp命令)
|
||||
//baseModelPath := filepath.Join("data/models", model.BaseModel)
|
||||
//fmt.Println("baseModelPath:", baseModelPath)
|
||||
//err = exec.Command("scp", baseModelPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run()
|
||||
//if err != nil {
|
||||
// fmt.Println(err)
|
||||
// return err
|
||||
//}
|
||||
//// 進行訓練(訓練機上調用訓練webapi接口:參數)
|
||||
//resp, err := http.Post(fmt.Sprintf("http://%s:5000/train", server.IP), "application/json", nil)
|
||||
//if err != nil {
|
||||
// fmt.Println(err)
|
||||
// return err
|
||||
//}
|
||||
//defer resp.Body.Close()
|
||||
//// 循環監聽訓練進度
|
||||
//for i := 0; i < 5; i++ {
|
||||
// // 訓練機上調用訓練webapi接口:獲取訓練進度
|
||||
// resp, err := http.Get(fmt.Sprintf("http://%s:5000/progress", server.IP))
|
||||
// if err != nil {
|
||||
// fmt.Println(err)
|
||||
// return err
|
||||
// }
|
||||
// defer resp.Body.Close()
|
||||
//// 更新本地訓練進度
|
||||
// var progress int
|
||||
// if err := json.NewDecoder(resp.Body).Decode(&progress); err != nil {
|
||||
// fmt.Println(err)
|
||||
// return err
|
||||
// }
|
||||
//}
|
||||
//
|
||||
// TODO: 訓練完成後將模型下載到本地
|
||||
return nil
|
||||
|
||||
}
|
||||
|
@@ -12,7 +12,7 @@ type Server struct {
|
||||
IP string `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
Status string `json:"status"` // (異常|初始化|閒置|就緒|工作中|關閉中)
|
||||
Username string `json:"username"`
|
||||
UserName string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Models []map[string]interface{} `json:"models" gorm:"-"` // 數據庫不必保存
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||
|
@@ -12,7 +12,7 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
// 獲取用戶列表
|
||||
// 用戶列表
|
||||
func UsersGet(w http.ResponseWriter, r *http.Request) {
|
||||
var listview models.ListView
|
||||
listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1)
|
||||
|
24
test.sh
24
test.sh
@@ -53,19 +53,19 @@ response=$(curl -X PATCH -H "Content-Type: application/json" -d '{"images":["htt
|
||||
[[ ${response: -3} -eq 200 ]] && { echo "修改數據集成功: ${response%???}"; } || exit_service "修改數據集失敗: ${response%???}"
|
||||
|
||||
|
||||
# 訓練模型 (POST /api/models)
|
||||
response=$(curl -X POST -H "Content-Type: application/json" -d '{"name":"test","type":"lora","trigger_words":"miao~","base_model":"sd1.5","epochs":20,"description":"test","dataset_id":'$dataset_id'}' -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models)
|
||||
[[ ${response: -3} -eq 200 ]] && { echo "訓練模型任務已創建: ${response%???}"; } || exit_service "訓練模型任務創建失敗: ${response%???}"
|
||||
## 訓練模型 (POST /api/models)
|
||||
#response=$(curl -X POST -H "Content-Type: application/json" -d '{"name":"test","type":"dreambooth","trigger_words":"miao~","base_model":"sd1.5","epochs":20,"description":"test","dataset_id":'$dataset_id'}' -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models)
|
||||
#[[ ${response: -3} -eq 200 ]] && { echo "訓練模型任務已創建: ${response%???}"; } || exit_service "訓練模型任務創建失敗: ${response%???}"
|
||||
#
|
||||
#
|
||||
## 取模型id的值, 值爲 int
|
||||
#model_id=$(echo "${response%???}" | grep -o '"id": [0-9]*' | awk '{print $2}')
|
||||
#echo "model_id: $model_id"
|
||||
#
|
||||
|
||||
|
||||
# 取模型id的值, 值爲 int
|
||||
model_id=$(echo "${response%???}" | grep -o '"id": [0-9]*' | awk '{print $2}')
|
||||
echo "model_id: $model_id"
|
||||
|
||||
|
||||
## 模型列表 (GET /api/models)
|
||||
# response=$(curl -X GET -H "Content-Type: application/json" -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models)
|
||||
# [[ ${response: -3} -eq 200 ]] && { echo "獲取模型列表成功: ${response%???}"; } || exit_service "獲取模型列表失敗: ${response%???}"
|
||||
# 模型列表 (GET /api/models)
|
||||
response=$(curl -X GET -H "Content-Type: application/json" -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models)
|
||||
[[ ${response: -3} -eq 200 ]] && { echo "獲取模型列表成功: ${response%???}"; } || exit_service "獲取模型列表失敗: ${response%???}"
|
||||
|
||||
|
||||
## 獲取模型訓練進度 (GET /api/models/:id)
|
||||
|
Reference in New Issue
Block a user