diff --git a/models/Model.go b/models/Model.go index 6f788c2..d9fd153 100644 --- a/models/Model.go +++ b/models/Model.go @@ -2,13 +2,12 @@ package models import ( "crypto/md5" - "encoding/json" "fmt" "io/ioutil" "main/configs" "net/http" + "net/url" "os" - "os/exec" "path/filepath" "time" ) @@ -115,53 +114,91 @@ func (model *Model) Train() (err error) { return fmt.Errorf("目錄下沒有文件") } - // 將文件全部上傳到訓練機(使用scp命令) - err = exec.Command("scp", "-r", dirPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run() - if err != nil { - fmt.Println(err) - return err + // 按類型執行訓練任務 + if model.Type == "dreambooth" { + // 創建數據庫模型 + fmt.Println("創建數據庫模型 ======================================") + resp, err := http.PostForm(fmt.Sprintf("http://%s:%d/dreambooth/createModel?new_model_name=%s&new_model_src=%s", server.IP, server.Port, model.Name, model.ModelPath), nil) + if err != nil { + fmt.Println("創建訓練任務失敗:", err.Error()) + return err + } + defer resp.Body.Close() + + // 打印返回的結果 + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + fmt.Println("解碼任務數據失敗:", err) + return err + } + fmt.Println("預覽:", string(body)) + + // 上傳數據到訓練機 + + // 執行訓練命令 } - // 刪除本地臨時目錄 - if err := os.RemoveAll(dirPath); err != nil { - fmt.Println(err) - return err - } - - // 将基础模型上传到训练机(使用scp命令) - baseModelPath := filepath.Join("data/models", model.BaseModel) - fmt.Println("baseModelPath:", baseModelPath) - err = exec.Command("scp", baseModelPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run() - if err != nil { - fmt.Println(err) - return err - } - - // 進行訓練(訓練機上調用訓練webapi接口:參數) - resp, err := http.Post(fmt.Sprintf("http://%s:5000/train", server.IP), "application/json", nil) - if err != nil { - fmt.Println(err) - return err - } - defer resp.Body.Close() - - // 循環監聽訓練進度 - for { - // 訓練機上調用訓練webapi接口:獲取訓練進度 - resp, err := http.Get(fmt.Sprintf("http://%s:5000/progress", server.IP)) + if model.Type == "lora" { + // 創建數據庫模型 + formData := url.Values{} + formData.Set("name", model.Name) + formData.Set("type", model.Type) + resp, err := http.PostForm(fmt.Sprintf("http://%s:%d/lora/createModel", server.IP, server.Port), formData) if err != nil { fmt.Println(err) return err } defer resp.Body.Close() - // 更新本地訓練進度 - var progress int - if err := json.NewDecoder(resp.Body).Decode(&progress); err != nil { - fmt.Println(err) - return err - } + // 上傳數據到訓練機 + + // 執行訓練命令 } + //// 將文件全部上傳到訓練機, 使用scp命令,自動使用密碼登錄 + //err = exec.Command("scp", "-r", dirPath, fmt.Sprintf("%s@%s:~/dataset_%d", server.UserName, server.IP, model.ID)).Run() + //if err != nil { + // fmt.Println(err) + // return err + //} + //// 刪除本地臨時目錄 + //if err := os.RemoveAll(dirPath); err != nil { + // fmt.Println(err) + // return err + //} + //// 将基础模型上传到训练机(使用scp命令) + //baseModelPath := filepath.Join("data/models", model.BaseModel) + //fmt.Println("baseModelPath:", baseModelPath) + //err = exec.Command("scp", baseModelPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run() + //if err != nil { + // fmt.Println(err) + // return err + //} + //// 進行訓練(訓練機上調用訓練webapi接口:參數) + //resp, err := http.Post(fmt.Sprintf("http://%s:5000/train", server.IP), "application/json", nil) + //if err != nil { + // fmt.Println(err) + // return err + //} + //defer resp.Body.Close() + //// 循環監聽訓練進度 + //for i := 0; i < 5; i++ { + // // 訓練機上調用訓練webapi接口:獲取訓練進度 + // resp, err := http.Get(fmt.Sprintf("http://%s:5000/progress", server.IP)) + // if err != nil { + // fmt.Println(err) + // return err + // } + // defer resp.Body.Close() + //// 更新本地訓練進度 + // var progress int + // if err := json.NewDecoder(resp.Body).Decode(&progress); err != nil { + // fmt.Println(err) + // return err + // } + //} + // // TODO: 訓練完成後將模型下載到本地 + return nil + } diff --git a/models/Server.go b/models/Server.go index 9f00550..4524eaa 100644 --- a/models/Server.go +++ b/models/Server.go @@ -12,7 +12,7 @@ type Server struct { IP string `json:"ip"` Port int `json:"port"` Status string `json:"status"` // (異常|初始化|閒置|就緒|工作中|關閉中) - Username string `json:"username"` + UserName string `json:"username"` Password string `json:"password"` Models []map[string]interface{} `json:"models" gorm:"-"` // 數據庫不必保存 CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` diff --git a/routers/users.go b/routers/users.go index e10025f..dd4e31e 100644 --- a/routers/users.go +++ b/routers/users.go @@ -12,7 +12,7 @@ import ( "github.com/gorilla/mux" ) -// 獲取用戶列表 +// 用戶列表 func UsersGet(w http.ResponseWriter, r *http.Request) { var listview models.ListView listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1) diff --git a/test.sh b/test.sh index fd039da..54b6000 100755 --- a/test.sh +++ b/test.sh @@ -53,19 +53,19 @@ response=$(curl -X PATCH -H "Content-Type: application/json" -d '{"images":["htt [[ ${response: -3} -eq 200 ]] && { echo "修改數據集成功: ${response%???}"; } || exit_service "修改數據集失敗: ${response%???}" -# 訓練模型 (POST /api/models) -response=$(curl -X POST -H "Content-Type: application/json" -d '{"name":"test","type":"lora","trigger_words":"miao~","base_model":"sd1.5","epochs":20,"description":"test","dataset_id":'$dataset_id'}' -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models) -[[ ${response: -3} -eq 200 ]] && { echo "訓練模型任務已創建: ${response%???}"; } || exit_service "訓練模型任務創建失敗: ${response%???}" +## 訓練模型 (POST /api/models) +#response=$(curl -X POST -H "Content-Type: application/json" -d '{"name":"test","type":"dreambooth","trigger_words":"miao~","base_model":"sd1.5","epochs":20,"description":"test","dataset_id":'$dataset_id'}' -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models) +#[[ ${response: -3} -eq 200 ]] && { echo "訓練模型任務已創建: ${response%???}"; } || exit_service "訓練模型任務創建失敗: ${response%???}" +# +# +## 取模型id的值, 值爲 int +#model_id=$(echo "${response%???}" | grep -o '"id": [0-9]*' | awk '{print $2}') +#echo "model_id: $model_id" +# - -# 取模型id的值, 值爲 int -model_id=$(echo "${response%???}" | grep -o '"id": [0-9]*' | awk '{print $2}') -echo "model_id: $model_id" - - -## 模型列表 (GET /api/models) -# response=$(curl -X GET -H "Content-Type: application/json" -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models) -# [[ ${response: -3} -eq 200 ]] && { echo "獲取模型列表成功: ${response%???}"; } || exit_service "獲取模型列表失敗: ${response%???}" +# 模型列表 (GET /api/models) + response=$(curl -X GET -H "Content-Type: application/json" -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models) + [[ ${response: -3} -eq 200 ]] && { echo "獲取模型列表成功: ${response%???}"; } || exit_service "獲取模型列表失敗: ${response%???}" ## 獲取模型訓練進度 (GET /api/models/:id) diff --git a/update.sh b/update.sh index 073baef..16b0dd7 100755 --- a/update.sh +++ b/update.sh @@ -6,6 +6,7 @@ go build -o data/gameui-ai-server main.go # 上传文件 scp ./data/gameui-ai-server root@47.103.40.152:~/gameui-ai-server_new +rm -rf ./data/gameui-ai-server # 重啓服務 ssh root@47.103.40.152 '''