diff --git a/models/Model.go b/models/Model.go index 455df18..47369b9 100644 --- a/models/Model.go +++ b/models/Model.go @@ -17,8 +17,10 @@ type Model struct { Progress int `json:"progress"` // (0-100) Image string `json:"image"` // 封面圖片實際地址 Hash string `json:"hash"` // 模型哈希值 - Tags string `json:"tags"` + Epochs int `json:"epochs"` // 訓練步數 + Tags TagList `json:"tags"` UserID int `json:"user_id"` + DatasetID int `json:"dataset_id"` CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` } diff --git a/models/Tag.go b/models/Tag.go index cc386d5..14e9856 100644 --- a/models/Tag.go +++ b/models/Tag.go @@ -1,6 +1,8 @@ package models import ( + "database/sql/driver" + "encoding/json" "main/configs" "time" ) @@ -15,3 +17,13 @@ type Tag struct { func init() { configs.ORMDB().AutoMigrate(&Tag{}) } + +type TagList []string + +func (list *TagList) Scan(value interface{}) error { + return json.Unmarshal(value.([]byte), list) +} + +func (list TagList) Value() (driver.Value, error) { + return json.Marshal(list) +} diff --git a/routers/models.go b/routers/models.go index 8d89f9f..cfbf84f 100644 --- a/routers/models.go +++ b/routers/models.go @@ -33,7 +33,7 @@ func ModelsGet(w http.ResponseWriter, r *http.Request) { listview.WriteJSON(w) } -// 創建模型 +// 創建模型(訓練新模型) func ModelsPost(w http.ResponseWriter, r *http.Request) { models.AccountRead(w, r, func(account *models.Account) { fmt.Println(account) @@ -50,6 +50,44 @@ func ModelsPost(w http.ResponseWriter, r *http.Request) { log.Println(err) return } + + if model.Name == "" { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("模型名稱不能為空")) + return + } + + if model.Type == "" { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("模型類型不能為空(recommend|lora|ckp|hyper|ti)")) + return + } + + if model.TriggerWords == "" { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("觸發詞不能為空")) + return + } + + if model.BaseModel == "" { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("基礎模型不能為空(SD1.5|SD2)")) + return + } + + if model.Epochs <= 0 { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("訓練輪數不能小於0")) + return + } + + if model.Tags == nil { + model.Tags = []string{} + } + + model.UserID = account.ID + model.Status = "initial" + if err := configs.ORMDB().Create(&model).Error; err != nil { log.Println(err) return @@ -62,6 +100,7 @@ func ModelsPost(w http.ResponseWriter, r *http.Request) { }) } +// 獲取模型詳情 func ModelItemGet(w http.ResponseWriter, r *http.Request) { if r.Header.Get("Upgrade") == "websocket" { vars := mux.Vars(r) @@ -97,8 +136,10 @@ func ModelItemGet(w http.ResponseWriter, r *http.Request) { } var model = models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} - if err := configs.ORMDB().Take(&model, utils.ParamInt(mux.Vars(r)["id"], 0)); err != nil { + fmt.Println(model) + if err := configs.ORMDB().Take(&model, utils.ParamInt(mux.Vars(r)["id"], 0)).Error; err != nil { w.WriteHeader(http.StatusNotFound) + w.Write([]byte(err.Error())) return } @@ -106,6 +147,7 @@ func ModelItemGet(w http.ResponseWriter, r *http.Request) { w.Write(utils.ToJSON(model)) } +// 更新模型 func ModelItemPatch(w http.ResponseWriter, r *http.Request) { var model = models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} if err := configs.ORMDB().Take(&model, utils.ParamInt(mux.Vars(r)["id"], 0)).Error; err != nil { @@ -157,6 +199,7 @@ func ModelItemPatch(w http.ResponseWriter, r *http.Request) { w.Write(utils.ToJSON(model)) } +// 刪除模型 func ModelItemDelete(w http.ResponseWriter, r *http.Request) { var model = models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} if err := configs.ORMDB().Take(&model, utils.ParamInt(mux.Vars(r)["id"], 0)); err != nil { diff --git a/test.sh b/test.sh index c4ee3a3..10f11d4 100755 --- a/test.sh +++ b/test.sh @@ -1,5 +1,7 @@ #!/bin/bash +# 記錄開始時間戳 +start_time=$(date +%s) # 流程測試, 啓動服務, 設定進程名 go_test go run main.go -procname go_test & @@ -50,4 +52,42 @@ response=$(curl -X PATCH -H "Content-Type: application/json" -d '{"images":["htt [[ ${response: -3} -eq 200 ]] && { echo "修改數據集成功: ${response%???}"; } || exit_service "修改數據集失敗: ${response%???}" -exit_service "測試結束, 全部通過" +# 訓練模型 (POST /api/models) +response=$(curl -X POST -H "Content-Type: application/json" -d '{"name":"test","type":"lora","trigger_words":"miao~","base_model":"sd1.5","epochs":20,"description":"test","dataset_id":'$dataset_id'}' -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models) +[[ ${response: -3} -eq 200 ]] && { echo "訓練模型任務已創建: ${response%???}"; } || exit_service "訓練模型任務創建失敗: ${response%???}" + + +# 取模型id的值, 值爲 int +model_id=$(echo "${response%???}" | grep -o '"id": [0-9]*' | awk '{print $2}') +echo "model_id: $model_id" + + +## 模型列表 (GET /api/models) +# response=$(curl -X GET -H "Content-Type: application/json" -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models) +# [[ ${response: -3} -eq 200 ]] && { echo "獲取模型列表成功: ${response%???}"; } || exit_service "獲取模型列表失敗: ${response%???}" + + +## 獲取模型訓練進度 (GET /api/models/:id) +#response=$(curl -X GET -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models/$model_id) +#[[ ${response: -3} -eq 200 ]] && { echo "獲取模型訓練進度成功: ${response%???}"; } || exit_service "獲取模型訓練進度失敗: ${response%???}" + + +# 循環獲取模型訓練進度, 直到訓練完成 +while true; do + # 獲取模型訓練進度 (GET /api/models/:id) + response=$(curl -X GET -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models/$model_id) + [[ ${response: -3} -eq 200 ]] && { echo "獲取模型訓練進度成功: ${response%???}"; } || exit_service "獲取模型訓練進度失敗: ${response%???}" + + # 取出進度字段的值, 值爲 int + progress=$(echo "${response%???}" | grep -o '"progress": [0-9]*' | awk '{print $2}') + echo "progress: $progress" + + # 如果進度爲 100, 訓練完成, 跳出循環 + [[ $progress -eq 100 ]] && { echo "訓練完成"; break; } + + # 休眠 5 秒 + sleep 5 +done + + +exit_service "測試結束, 全部通過, 用費時間: $(($(date +%s) - $start_time)) 秒"