From 0966a3c83e39583aa4acfa08b86463531ffb10d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A1=9C=E8=8F=AF?= Date: Sat, 27 May 2023 16:40:04 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=8B=E8=BC=89=E5=9C=96=E5=83=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- models/Dataset.go | 21 +++++++ models/Image.go | 12 ++++ models/Model.go | 137 ++++++++++++++++++++++++++++++++++++++++---- models/Server.go | 11 +++- routers/datasets.go | 44 +++----------- routers/models.go | 2 + test.sh | 38 ++++++------ 7 files changed, 203 insertions(+), 62 deletions(-) create mode 100644 models/Dataset.go diff --git a/models/Dataset.go b/models/Dataset.go new file mode 100644 index 0000000..464a0d9 --- /dev/null +++ b/models/Dataset.go @@ -0,0 +1,21 @@ +package models + +import ( + "main/configs" + "time" +) + +// 數據集模型 +type Dataset struct { + ID int `json:"id" gorm:"primary_key"` + Name string `json:"name"` + Info string `json:"info"` + Images ImageList `json:"images"` + UserID int `json:"user_id"` + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` +} + +func init() { + configs.ORMDB().AutoMigrate(&Dataset{}) +} diff --git a/models/Image.go b/models/Image.go index 6e49b3a..ea24499 100644 --- a/models/Image.go +++ b/models/Image.go @@ -1,6 +1,8 @@ package models import ( + "database/sql/driver" + "encoding/json" "main/configs" "time" ) @@ -25,3 +27,13 @@ type Image struct { func init() { configs.ORMDB().AutoMigrate(&Image{}) } + +type ImageList []string + +func (list *ImageList) Scan(value interface{}) error { + return json.Unmarshal(value.([]byte), list) +} + +func (list ImageList) Value() (driver.Value, error) { + return json.Marshal(list) +} diff --git a/models/Model.go b/models/Model.go index 47369b9..6f788c2 100644 --- a/models/Model.go +++ b/models/Model.go @@ -1,8 +1,15 @@ package models import ( + "crypto/md5" + "encoding/json" "fmt" + "io/ioutil" "main/configs" + "net/http" + "os" + "os/exec" + "path/filepath" "time" ) @@ -21,6 +28,7 @@ type Model struct { Tags TagList `json:"tags"` UserID int `json:"user_id"` DatasetID int `json:"dataset_id"` + ServerID int `json:"server_id"` CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` } @@ -30,21 +38,130 @@ func init() { } func (model *Model) Train() (err error) { - if model.Type == "lora" { - fmt.Println("lora") + + // 獲取一臺空閒的訓練機 + var server Server + if err = configs.ORMDB().Where("status = ?", "閒置").First(&server).Error; err != nil { + fmt.Println(err) return } - if model.Type == "ckp" { - fmt.Println("ckp") + + // 獲取數據集 + var dataset Dataset = Dataset{ID: model.DatasetID} + if err = configs.ORMDB().First(&dataset).Error; err != nil { + fmt.Println(err) return } - if model.Type == "hyper" { - fmt.Println("hyper") + + // 更新模型狀態 + model.ServerID = server.ID + model.Status = "training" + if err = configs.ORMDB().Save(&model).Error; err != nil { + fmt.Println(err) return } - if model.Type == "ti" { - fmt.Println("ti") - return + + // 創建數據集目錄 + dirPath := filepath.Join("data/datasets", fmt.Sprint(dataset.ID), "images") + if err := os.MkdirAll(dirPath, 0755); err != nil { + fmt.Println(err) + return err } - return + + // 將數據下載到本地 + for index, url := range dataset.Images { + fmt.Println("下載數據到本地:", index, url) + + // 下載到臨時目錄 + resp, err := http.Get(url) + if err != nil { + fmt.Println("下載失敗:", err) + continue + } + defer resp.Body.Close() + + data, err := ioutil.ReadAll(resp.Body) + if err != nil { + fmt.Println("保存失敗:", err) + continue + } + + // 保存文件到本地目錄下(自動創建目錄,文件名為url的md5值) + filename := fmt.Sprintf("%x", md5.Sum([]byte(url))) + filePath := filepath.Join(dirPath, filename) + if err := os.MkdirAll(dirPath, 0755); err != nil { + fmt.Println(err) + continue + } + + if err := ioutil.WriteFile(filePath, data, 0644); err != nil { + fmt.Println(err) + continue + } + + } + + fmt.Println("數據下載完成") + + // 檢查目錄下是否有文件, 如果沒有文件則返回錯誤 + files, err := ioutil.ReadDir(dirPath) + if err != nil { + fmt.Println(err) + return err + } + + if len(files) == 0 { + fmt.Println("目錄下沒有文件") + return fmt.Errorf("目錄下沒有文件") + } + + // 將文件全部上傳到訓練機(使用scp命令) + err = exec.Command("scp", "-r", dirPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run() + if err != nil { + fmt.Println(err) + return err + } + + // 刪除本地臨時目錄 + if err := os.RemoveAll(dirPath); err != nil { + fmt.Println(err) + return err + } + + // 将基础模型上传到训练机(使用scp命令) + baseModelPath := filepath.Join("data/models", model.BaseModel) + fmt.Println("baseModelPath:", baseModelPath) + err = exec.Command("scp", baseModelPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run() + if err != nil { + fmt.Println(err) + return err + } + + // 進行訓練(訓練機上調用訓練webapi接口:參數) + resp, err := http.Post(fmt.Sprintf("http://%s:5000/train", server.IP), "application/json", nil) + if err != nil { + fmt.Println(err) + return err + } + defer resp.Body.Close() + + // 循環監聽訓練進度 + for { + // 訓練機上調用訓練webapi接口:獲取訓練進度 + resp, err := http.Get(fmt.Sprintf("http://%s:5000/progress", server.IP)) + if err != nil { + fmt.Println(err) + return err + } + defer resp.Body.Close() + + // 更新本地訓練進度 + var progress int + if err := json.NewDecoder(resp.Body).Decode(&progress); err != nil { + fmt.Println(err) + return err + } + } + + // TODO: 訓練完成後將模型下載到本地 } diff --git a/models/Server.go b/models/Server.go index f18ad90..9f00550 100644 --- a/models/Server.go +++ b/models/Server.go @@ -11,7 +11,7 @@ type Server struct { Type string `json:"type"` // (訓練|推理) IP string `json:"ip"` Port int `json:"port"` - Status string `json:"status"` // (異常|初始化|就緒|工作中|關閉中) + Status string `json:"status"` // (異常|初始化|閒置|就緒|工作中|關閉中) Username string `json:"username"` Password string `json:"password"` Models []map[string]interface{} `json:"models" gorm:"-"` // 數據庫不必保存 @@ -21,4 +21,13 @@ type Server struct { func init() { configs.ORMDB().AutoMigrate(&Server{}) + + // 添加一個預設的訓練機 + configs.ORMDB().Create(&Server{ + Name: "GPU T4", + Type: "train", + IP: "106.15.192.42", + Port: 7860, + Status: "閒置", + }) } diff --git a/routers/datasets.go b/routers/datasets.go index a224c2c..5b88fac 100644 --- a/routers/datasets.go +++ b/routers/datasets.go @@ -1,57 +1,31 @@ package routers import ( - "database/sql/driver" "encoding/json" "io/ioutil" "main/configs" "main/models" "main/utils" "net/http" - "time" "github.com/gorilla/mux" ) -type ImageList []string - -func (list *ImageList) Scan(value interface{}) error { - return json.Unmarshal(value.([]byte), list) -} -func (list ImageList) Value() (driver.Value, error) { - return json.Marshal(list) -} - -// 數據集模型 -type Dataset struct { - ID int `json:"id" gorm:"primary_key"` - Name string `json:"name"` - Info string `json:"info"` - Images ImageList `json:"images"` - UserID int `json:"user_id"` - CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` - UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` -} - -func init() { - configs.ORMDB().AutoMigrate(&Dataset{}) -} - // 獲取數據集列表 func DatasetsGet(w http.ResponseWriter, r *http.Request) { var listview models.ListView listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1) listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10) - var dataset_list []Dataset + var dataset_list []models.Dataset db := configs.ORMDB() db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Find(&dataset_list) for _, dataset := range dataset_list { if dataset.Images == nil { - dataset.Images = ImageList{} + dataset.Images = models.ImageList{} } listview.List = append(listview.List, dataset) } - db.Model(&Dataset{}).Count(&listview.Total) + db.Model(&models.Dataset{}).Count(&listview.Total) listview.Next = listview.Page*listview.PageSize < int(listview.Total) listview.WriteJSON(w) } @@ -59,7 +33,7 @@ func DatasetsGet(w http.ResponseWriter, r *http.Request) { // 創建數據集 func DatasetsPost(w http.ResponseWriter, r *http.Request) { models.AccountRead(w, r, func(account *models.Account) { - var dataset Dataset + var dataset models.Dataset body, err := ioutil.ReadAll(r.Body) if err != nil { w.WriteHeader(http.StatusInternalServerError) @@ -73,7 +47,7 @@ func DatasetsPost(w http.ResponseWriter, r *http.Request) { return } if dataset.Images == nil { - dataset.Images = ImageList{} + dataset.Images = models.ImageList{} } dataset.UserID = account.ID if err := configs.ORMDB().Create(&dataset).Error; err != nil { @@ -88,7 +62,7 @@ func DatasetsPost(w http.ResponseWriter, r *http.Request) { // 獲取數據集 func DatasetsItemGet(w http.ResponseWriter, r *http.Request) { - dataset := Dataset{ID: utils.ParamInt(mux.Vars(r)["dataset_id"], 0)} + dataset := models.Dataset{ID: utils.ParamInt(mux.Vars(r)["dataset_id"], 0)} if err := configs.ORMDB().Find(&dataset).Error; err != nil { w.WriteHeader(http.StatusNotFound) w.Write([]byte("404 - Not Found")) @@ -101,7 +75,7 @@ func DatasetsItemGet(w http.ResponseWriter, r *http.Request) { // 修改數據集 func DatasetsItemPatch(w http.ResponseWriter, r *http.Request) { models.AccountRead(w, r, func(account *models.Account) { - var dataset Dataset = Dataset{ID: utils.ParamInt(mux.Vars(r)["dataset_id"], 0)} + var dataset models.Dataset = models.Dataset{ID: utils.ParamInt(mux.Vars(r)["dataset_id"], 0)} if err := configs.ORMDB().Find(&dataset).Error; err != nil { w.WriteHeader(http.StatusNotFound) w.Write([]byte("404 - Not Found")) @@ -120,7 +94,7 @@ func DatasetsItemPatch(w http.ResponseWriter, r *http.Request) { dataset.Info = info } if images, ok := form["images"].([]interface{}); ok { - var image_list ImageList + var image_list models.ImageList for _, image := range images { if image, ok := image.(string); ok { image_list = append(image_list, image) @@ -142,7 +116,7 @@ func DatasetsItemPatch(w http.ResponseWriter, r *http.Request) { func DatasetsItemDelete(w http.ResponseWriter, r *http.Request) { models.AccountRead(w, r, func(account *models.Account) { // 獲取數據集 - dataset := Dataset{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} + dataset := models.Dataset{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} if err := configs.ORMDB().Find(&dataset).Error; err != nil { w.WriteHeader(http.StatusNotFound) w.Write([]byte("404 - Not Found")) diff --git a/routers/models.go b/routers/models.go index cfbf84f..16784d4 100644 --- a/routers/models.go +++ b/routers/models.go @@ -92,8 +92,10 @@ func ModelsPost(w http.ResponseWriter, r *http.Request) { log.Println(err) return } + // 直接提交訓練任務 go model.Train() + // 返回創建的模型 w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(model)) diff --git a/test.sh b/test.sh index 10f11d4..fd039da 100755 --- a/test.sh +++ b/test.sh @@ -2,6 +2,7 @@ # 記錄開始時間戳 start_time=$(date +%s) +rm -f data/sqlite3.db # 流程測試, 啓動服務, 設定進程名 go_test go run main.go -procname go_test & @@ -48,7 +49,7 @@ echo "dataset_id: $dataset_id" # 修改數據集, images 中增加 url (PATCH /api/datasets/:id) -response=$(curl -X PATCH -H "Content-Type: application/json" -d '{"images":["https://www.google.com/images/branding/googlelogo/2x/googlelogo_color_272x92dp.png"]}' -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/datasets/$dataset_id) +response=$(curl -X PATCH -H "Content-Type: application/json" -d '{"images":["https://img.gameui.net/article-7258-1677745322000@1x456.webp","https://img.gameui.net/article-6477-1682109454000@1x456.webp"]}' -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/datasets/$dataset_id) [[ ${response: -3} -eq 200 ]] && { echo "修改數據集成功: ${response%???}"; } || exit_service "修改數據集失敗: ${response%???}" @@ -72,22 +73,27 @@ echo "model_id: $model_id" #[[ ${response: -3} -eq 200 ]] && { echo "獲取模型訓練進度成功: ${response%???}"; } || exit_service "獲取模型訓練進度失敗: ${response%???}" -# 循環獲取模型訓練進度, 直到訓練完成 -while true; do - # 獲取模型訓練進度 (GET /api/models/:id) - response=$(curl -X GET -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models/$model_id) - [[ ${response: -3} -eq 200 ]] && { echo "獲取模型訓練進度成功: ${response%???}"; } || exit_service "獲取模型訓練進度失敗: ${response%???}" +## 循環獲取模型訓練進度, 直到訓練完成 +#while true; do +# # 獲取模型訓練進度 (GET /api/models/:id) +# response=$(curl -X GET -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models/$model_id) +# [[ ${response: -3} -eq 200 ]] && { echo "獲取模型訓練進度成功: ${response%???}"; } || exit_service "獲取模型訓練進度失敗: ${response%???}" +# +# # 取出進度字段的值, 值爲 int +# progress=$(echo "${response%???}" | grep -o '"progress": [0-9]*' | awk '{print $2}') +# echo "progress: $progress" +# +# # 如果進度爲 100, 訓練完成, 跳出循環 +# [[ $progress -eq 100 ]] && { echo "訓練完成"; break; } +# +# # 休眠 5 秒 +# sleep 5 +#done - # 取出進度字段的值, 值爲 int - progress=$(echo "${response%???}" | grep -o '"progress": [0-9]*' | awk '{print $2}') - echo "progress: $progress" - - # 如果進度爲 100, 訓練完成, 跳出循環 - [[ $progress -eq 100 ]] && { echo "訓練完成"; break; } - - # 休眠 5 秒 - sleep 5 -done +# 服務器列表 +response=$(curl -X GET -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/servers) +[[ ${response: -3} -eq 200 ]] && { echo "獲取服務器列表成功: ${response%???}"; } || exit_service "獲取服務器列表失敗: ${response%???}" +sleep 10 exit_service "測試結束, 全部通過, 用費時間: $(($(date +%s) - $start_time)) 秒"