From c832e0f4e9c5b1b294063a488d3c30336f1d8bc4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A1=9C=E8=8F=AF?= Date: Mon, 29 May 2023 13:48:09 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=B9=E7=88=B2=E6=89=8B=E5=8B=95=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E6=9C=8D=E5=8B=99=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- models/Server.go | 70 +++++++++++++++++++++------------------------- routers/servers.go | 35 ++++++++++++++++++++--- test.sh | 11 ++++++-- 3 files changed, 71 insertions(+), 45 deletions(-) diff --git a/models/Server.go b/models/Server.go index 598b026..b9df918 100644 --- a/models/Server.go +++ b/models/Server.go @@ -3,7 +3,6 @@ package models import ( "encoding/json" "fmt" - "log" "main/configs" "net/http" "time" @@ -23,52 +22,47 @@ type Server struct { UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` } -func (server *Server) CheckStatus() (err error) { - // 不用類型的模型有不同的狀態檢查方式 - if server.Type == "train" { - } +func (server *Server) CheckStatus() error { + switch server.Type { + case "訓練": + resp, err := http.Get(fmt.Sprintf("http://%s:%d/dreambooth/status", server.IP, server.Port)) + if err != nil { + server.Status = "異常" + return err + } + defer resp.Body.Close() - resp, err := http.Get(fmt.Sprintf("http://%s:%d/status", server.IP, server.Port)) - if err != nil { - log.Println("服務器狀態異常", err) + // 解碼JSON + var data map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { + return err + } + + // 解碼JSON + var current_state map[string]interface{} + if err := json.Unmarshal([]byte(data["current_state"].(string)), ¤t_state); err != nil { + return err + } + //log.Println("current_state:", current_state) + + // 檢查服務器是否正常 + if !current_state["active"].(bool) { + server.Status = "異常" + return fmt.Errorf("服務器狀態異常: active=false") + } + server.Status = "正常" + case "推理": + server.Status = "異常" + default: server.Status = "異常" - return } - defer resp.Body.Close() - - // 解碼json - var data map[string]interface{} - if err = json.NewDecoder(resp.Body).Decode(&data); err != nil { - return - } - - log.Println("data:", data) // 檢查服務器是否正常 - if data["status"] != "ok" { - log.Println("服務器狀態異常", err) - server.Status = "異常" - return - } - - configs.ORMDB().Save(&server) - - // 檢查服務器是否正常 - return + return nil } func init() { configs.ORMDB().AutoMigrate(&Server{}) - - // 添加一個預設的訓練機 - configs.ORMDB().Create(&Server{ - Name: "GPU T4", - Type: "train", - IP: "106.15.192.42", - Port: 7860, - Status: "閒置", - }) - // 檢查所有服務器的狀態, 無效的服務器設置為異常 var servers []Server configs.ORMDB().Find(&servers) diff --git a/routers/servers.go b/routers/servers.go index 3d629b2..8f7c667 100644 --- a/routers/servers.go +++ b/routers/servers.go @@ -48,9 +48,18 @@ func ServersPost(w http.ResponseWriter, r *http.Request) { var server models.Server // 獲取參數 - body, _ := ioutil.ReadAll(r.Body) + body, err := ioutil.ReadAll(r.Body) + if err != nil { + fmt.Println("獲取數據失敗:", err) + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(err.Error())) + return + } defer r.Body.Close() + + // 解碼JSON if err := json.Unmarshal(body, &server); err != nil { + fmt.Println("解碼JSON失敗:", err) w.WriteHeader(http.StatusBadRequest) w.Write([]byte(err.Error())) return @@ -69,18 +78,36 @@ func ServersPost(w http.ResponseWriter, r *http.Request) { } // 如果不指定 port,則使用默認 port - if server.Port == 0 { + if server.Port <= 0 { server.Port = 7860 } // 如果不指定IP,則自動創建新服務器 if server.IP == "" { // TODO: 創建新服務器 - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Write(utils.ToJSON(server)) + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("必須指定IP, 因爲當前禁止自動創建服務器")) return } + // 檢查服務器是否已經存在 + var count int64 + configs.ORMDB().Model(&models.Server{}).Where("ip = ?", server.IP).Count(&count) + if count > 0 { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("服務器已經存在")) + return + } + + // 檢查服務器狀態是否正常 + err = server.CheckStatus() + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("服務器狀態錯誤:" + err.Error())) + return + } + + // 創建服務器 configs.ORMDB().Create(&server) w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(server)) diff --git a/test.sh b/test.sh index 962b844..26103cb 100755 --- a/test.sh +++ b/test.sh @@ -68,9 +68,14 @@ response=$(curl -X PATCH -H "Content-Type: application/json" -d '{"images":["htt message "$response" "修改數據集" +# 添加服務器 (POST /api/servers) +response=$(curl -X POST -H "Content-Type: application/json" -d '{"name":"GPU-T4","type":"訓練","ip":"106.15.192.42","port":7860}' -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/servers) +message "$response" "添加服務器" + + # 服務器列表 response=$(curl -X GET -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/servers) -message "$response" "服務器列表" true +message "$response" "服務器列表" # 創建模型訓練任務 (POST /api/models) @@ -92,8 +97,8 @@ while true; do message "$response" "獲取模型訓練進度 $progress% $status" # 如果進度爲 100, 訓練完成, 跳出循環 [[ $progress -eq 100 ]] && { echo "訓練完成"; break; } - # 測試訓練時間不超過20秒, 超過則退出 - [[ $(($(date +%s) - $start_time)) -gt 20 ]] && exit_service "訓練時間超過20秒" + # 測試訓練時間不超過10秒, 超過則退出 + [[ $(($(date +%s) - $start_time)) -gt 10 ]] && exit_service "訓練時間超過20秒" # 休眠 3 秒 sleep 3 done