訓練進度測試
This commit is contained in:
		@@ -17,8 +17,10 @@ type Model struct {
 | 
				
			|||||||
	Progress     int       `json:"progress"`                 // (0-100)
 | 
						Progress     int       `json:"progress"`                 // (0-100)
 | 
				
			||||||
	Image        string    `json:"image"`                    // 封面圖片實際地址
 | 
						Image        string    `json:"image"`                    // 封面圖片實際地址
 | 
				
			||||||
	Hash         string    `json:"hash"`                     // 模型哈希值
 | 
						Hash         string    `json:"hash"`                     // 模型哈希值
 | 
				
			||||||
	Tags         string    `json:"tags"`
 | 
						Epochs       int       `json:"epochs"`                   // 訓練步數
 | 
				
			||||||
 | 
						Tags         TagList   `json:"tags"`
 | 
				
			||||||
	UserID       int       `json:"user_id"`
 | 
						UserID       int       `json:"user_id"`
 | 
				
			||||||
 | 
						DatasetID    int       `json:"dataset_id"`
 | 
				
			||||||
	CreatedAt    time.Time `json:"created_at" gorm:"autoCreateTime"`
 | 
						CreatedAt    time.Time `json:"created_at" gorm:"autoCreateTime"`
 | 
				
			||||||
	UpdatedAt    time.Time `json:"updated_at" gorm:"autoUpdateTime"`
 | 
						UpdatedAt    time.Time `json:"updated_at" gorm:"autoUpdateTime"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,6 +1,8 @@
 | 
				
			|||||||
package models
 | 
					package models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"database/sql/driver"
 | 
				
			||||||
 | 
						"encoding/json"
 | 
				
			||||||
	"main/configs"
 | 
						"main/configs"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@@ -15,3 +17,13 @@ type Tag struct {
 | 
				
			|||||||
func init() {
 | 
					func init() {
 | 
				
			||||||
	configs.ORMDB().AutoMigrate(&Tag{})
 | 
						configs.ORMDB().AutoMigrate(&Tag{})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type TagList []string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (list *TagList) Scan(value interface{}) error {
 | 
				
			||||||
 | 
						return json.Unmarshal(value.([]byte), list)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (list TagList) Value() (driver.Value, error) {
 | 
				
			||||||
 | 
						return json.Marshal(list)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -33,7 +33,7 @@ func ModelsGet(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
	listview.WriteJSON(w)
 | 
						listview.WriteJSON(w)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 創建模型
 | 
					// 創建模型(訓練新模型)
 | 
				
			||||||
func ModelsPost(w http.ResponseWriter, r *http.Request) {
 | 
					func ModelsPost(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
	models.AccountRead(w, r, func(account *models.Account) {
 | 
						models.AccountRead(w, r, func(account *models.Account) {
 | 
				
			||||||
		fmt.Println(account)
 | 
							fmt.Println(account)
 | 
				
			||||||
@@ -50,6 +50,44 @@ func ModelsPost(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
			log.Println(err)
 | 
								log.Println(err)
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if model.Name == "" {
 | 
				
			||||||
 | 
								w.WriteHeader(http.StatusBadRequest)
 | 
				
			||||||
 | 
								w.Write([]byte("模型名稱不能為空"))
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if model.Type == "" {
 | 
				
			||||||
 | 
								w.WriteHeader(http.StatusBadRequest)
 | 
				
			||||||
 | 
								w.Write([]byte("模型類型不能為空(recommend|lora|ckp|hyper|ti)"))
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if model.TriggerWords == "" {
 | 
				
			||||||
 | 
								w.WriteHeader(http.StatusBadRequest)
 | 
				
			||||||
 | 
								w.Write([]byte("觸發詞不能為空"))
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if model.BaseModel == "" {
 | 
				
			||||||
 | 
								w.WriteHeader(http.StatusBadRequest)
 | 
				
			||||||
 | 
								w.Write([]byte("基礎模型不能為空(SD1.5|SD2)"))
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if model.Epochs <= 0 {
 | 
				
			||||||
 | 
								w.WriteHeader(http.StatusBadRequest)
 | 
				
			||||||
 | 
								w.Write([]byte("訓練輪數不能小於0"))
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if model.Tags == nil {
 | 
				
			||||||
 | 
								model.Tags = []string{}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							model.UserID = account.ID
 | 
				
			||||||
 | 
							model.Status = "initial"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if err := configs.ORMDB().Create(&model).Error; err != nil {
 | 
							if err := configs.ORMDB().Create(&model).Error; err != nil {
 | 
				
			||||||
			log.Println(err)
 | 
								log.Println(err)
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
@@ -62,6 +100,7 @@ func ModelsPost(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 獲取模型詳情
 | 
				
			||||||
func ModelItemGet(w http.ResponseWriter, r *http.Request) {
 | 
					func ModelItemGet(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
	if r.Header.Get("Upgrade") == "websocket" {
 | 
						if r.Header.Get("Upgrade") == "websocket" {
 | 
				
			||||||
		vars := mux.Vars(r)
 | 
							vars := mux.Vars(r)
 | 
				
			||||||
@@ -97,8 +136,10 @@ func ModelItemGet(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var model = models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
 | 
						var model = models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
 | 
				
			||||||
	if err := configs.ORMDB().Take(&model, utils.ParamInt(mux.Vars(r)["id"], 0)); err != nil {
 | 
						fmt.Println(model)
 | 
				
			||||||
 | 
						if err := configs.ORMDB().Take(&model, utils.ParamInt(mux.Vars(r)["id"], 0)).Error; err != nil {
 | 
				
			||||||
		w.WriteHeader(http.StatusNotFound)
 | 
							w.WriteHeader(http.StatusNotFound)
 | 
				
			||||||
 | 
							w.Write([]byte(err.Error()))
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -106,6 +147,7 @@ func ModelItemGet(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
	w.Write(utils.ToJSON(model))
 | 
						w.Write(utils.ToJSON(model))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 更新模型
 | 
				
			||||||
func ModelItemPatch(w http.ResponseWriter, r *http.Request) {
 | 
					func ModelItemPatch(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
	var model = models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
 | 
						var model = models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
 | 
				
			||||||
	if err := configs.ORMDB().Take(&model, utils.ParamInt(mux.Vars(r)["id"], 0)).Error; err != nil {
 | 
						if err := configs.ORMDB().Take(&model, utils.ParamInt(mux.Vars(r)["id"], 0)).Error; err != nil {
 | 
				
			||||||
@@ -157,6 +199,7 @@ func ModelItemPatch(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
	w.Write(utils.ToJSON(model))
 | 
						w.Write(utils.ToJSON(model))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 刪除模型
 | 
				
			||||||
func ModelItemDelete(w http.ResponseWriter, r *http.Request) {
 | 
					func ModelItemDelete(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
	var model = models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
 | 
						var model = models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
 | 
				
			||||||
	if err := configs.ORMDB().Take(&model, utils.ParamInt(mux.Vars(r)["id"], 0)); err != nil {
 | 
						if err := configs.ORMDB().Take(&model, utils.ParamInt(mux.Vars(r)["id"], 0)); err != nil {
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										42
									
								
								test.sh
									
									
									
									
									
								
							
							
						
						
									
										42
									
								
								test.sh
									
									
									
									
									
								
							@@ -1,5 +1,7 @@
 | 
				
			|||||||
#!/bin/bash
 | 
					#!/bin/bash
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 記錄開始時間戳
 | 
				
			||||||
 | 
					start_time=$(date +%s)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# 流程測試, 啓動服務, 設定進程名 go_test
 | 
					# 流程測試, 啓動服務, 設定進程名 go_test
 | 
				
			||||||
go run main.go -procname go_test &
 | 
					go run main.go -procname go_test &
 | 
				
			||||||
@@ -50,4 +52,42 @@ response=$(curl -X PATCH -H "Content-Type: application/json" -d '{"images":["htt
 | 
				
			|||||||
[[ ${response: -3} -eq 200 ]] && { echo "修改數據集成功: ${response%???}"; } || exit_service "修改數據集失敗: ${response%???}"
 | 
					[[ ${response: -3} -eq 200 ]] && { echo "修改數據集成功: ${response%???}"; } || exit_service "修改數據集失敗: ${response%???}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
exit_service "測試結束, 全部通過"
 | 
					# 訓練模型 (POST /api/models)
 | 
				
			||||||
 | 
					response=$(curl -X POST -H "Content-Type: application/json" -d '{"name":"test","type":"lora","trigger_words":"miao~","base_model":"sd1.5","epochs":20,"description":"test","dataset_id":'$dataset_id'}' -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models)
 | 
				
			||||||
 | 
					[[ ${response: -3} -eq 200 ]] && { echo "訓練模型任務已創建: ${response%???}"; } || exit_service "訓練模型任務創建失敗: ${response%???}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 取模型id的值, 值爲 int
 | 
				
			||||||
 | 
					model_id=$(echo "${response%???}" | grep -o '"id": [0-9]*' | awk '{print $2}')
 | 
				
			||||||
 | 
					echo "model_id: $model_id"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## 模型列表 (GET /api/models)
 | 
				
			||||||
 | 
					# response=$(curl -X GET -H "Content-Type: application/json" -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models)
 | 
				
			||||||
 | 
					# [[ ${response: -3} -eq 200 ]] && { echo "獲取模型列表成功: ${response%???}"; } || exit_service "獲取模型列表失敗: ${response%???}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## 獲取模型訓練進度 (GET /api/models/:id)
 | 
				
			||||||
 | 
					#response=$(curl -X GET -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models/$model_id)
 | 
				
			||||||
 | 
					#[[ ${response: -3} -eq 200 ]] && { echo "獲取模型訓練進度成功: ${response%???}"; } || exit_service "獲取模型訓練進度失敗: ${response%???}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 循環獲取模型訓練進度, 直到訓練完成
 | 
				
			||||||
 | 
					while true; do
 | 
				
			||||||
 | 
					    # 獲取模型訓練進度 (GET /api/models/:id)
 | 
				
			||||||
 | 
					    response=$(curl -X GET -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models/$model_id)
 | 
				
			||||||
 | 
					    [[ ${response: -3} -eq 200 ]] && { echo "獲取模型訓練進度成功: ${response%???}"; } || exit_service "獲取模型訓練進度失敗: ${response%???}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # 取出進度字段的值, 值爲 int
 | 
				
			||||||
 | 
					    progress=$(echo "${response%???}" | grep -o '"progress": [0-9]*' | awk '{print $2}')
 | 
				
			||||||
 | 
					    echo "progress: $progress"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # 如果進度爲 100, 訓練完成, 跳出循環
 | 
				
			||||||
 | 
					    [[ $progress -eq 100 ]] && { echo "訓練完成"; break; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # 休眠 5 秒
 | 
				
			||||||
 | 
					    sleep 5
 | 
				
			||||||
 | 
					done
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					exit_service "測試結束, 全部通過, 用費時間: $(($(date +%s) - $start_time)) 秒"
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user