From 4afb719d134c1a00bb4e9cf268b3b8fa12b19b44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A1=9C=E8=8F=AF?= Date: Sun, 2 Jul 2023 04:22:24 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=9D=E5=A7=8B=E5=8C=96=E9=BB=98=E8=AE=A4?= =?UTF-8?q?=E6=9C=8D=E5=8A=A1=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- models/Model.go | 66 ++++++++++++----------- models/server.go | 61 +++++++++++++++++++-- test.sh | 136 +++++++++++++++++++++++++---------------------- 3 files changed, 165 insertions(+), 98 deletions(-) diff --git a/models/Model.go b/models/Model.go index 7fed8c9..8e3ff5e 100644 --- a/models/Model.go +++ b/models/Model.go @@ -11,10 +11,10 @@ import ( "net/url" "os" "path/filepath" - "strconv" "time" "encoding/base64" + "encoding/json" "image/png" "github.com/chai2010/webp" @@ -22,26 +22,27 @@ import ( ) type Model struct { - ID int `json:"id" gorm:"primary_key"` // 模型ID - Name string `json:"name"` // 模型名稱 - Info string `json:"info"` // 模型描述 - Type string `json:"type"` // 模型類型(lora|ckp|hyper|ti) - TriggerWords string `json:"trigger_words"` // 觸發詞 - BaseModel string `json:"base_model"` // 基礎模型(SD1.5|SD2) - ModelPath string `json:"model_path"` // 模型路徑(實際存放在服務器上的路徑) - Status string `json:"status" default:"initial"` // (initial|ready|waiting|running|success|error) - Progress int `json:"progress"` // (0-100) - Image string `json:"image"` // 封面圖片實際地址 - Hash string `json:"hash"` // 模型哈希值 - Epochs int `json:"epochs"` // 訓練步數 - LearningRate float32 `json:"learning_rate"` // 學習率(0.000005) - Tags TagList `json:"tags"` // 模型標籤(標籤名數組) - UserID int `json:"user_id"` // 模型的所有者 - DatasetID int `json:"dataset_id"` // 模型所使用的數據集ID - ServerID string `json:"server_id"` // 模型所在服務器(訓練機或推理機) - Stars StarList `json:"stars"` // 模型的收藏者 - CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` - UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` + ID int `json:"id" gorm:"primary_key"` // 模型ID + Name string `json:"name"` // 模型名稱 + ModelCheckpoint string `json:"model_checkpoint"` // 模型檢查點 + Info string `json:"info"` // 模型描述 + Type string `json:"type"` // 模型類型(lora|ckp|hyper|ti) + TriggerWords string `json:"trigger_words"` // 觸發詞 + BaseModel string `json:"base_model"` // 基礎模型(SD1.5|SD2) + ModelPath string `json:"model_path"` // 模型路徑(實際存放在服務器上的路徑) + Status string `json:"status" default:"initial"` // (initial|ready|waiting|running|success|error) + Progress int `json:"progress"` // (0-100) + Image string `json:"image"` // 封面圖片實際地址 + Hash string `json:"hash"` // 模型哈希值(sha256) + Epochs int `json:"epochs"` // 訓練步數 + LearningRate float32 `json:"learning_rate"` // 學習率(0.000005) + Tags TagList `json:"tags"` // 模型標籤(標籤名數組) + UserID int `json:"user_id"` // 模型的所有者 + DatasetID int `json:"dataset_id"` // 模型所使用的數據集ID + ServerID string `json:"server_id"` // 模型所在服務器(訓練機或推理機) + Stars StarList `json:"stars"` // 模型的收藏者 + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` } func init() { @@ -51,6 +52,10 @@ func init() { log.Println(err) } } + // 清除所有hash长度小于32的模型 + configs.ORMDB().Where("length(hash) < 32").Delete(&Model{}) + // 清除所有type为空的模型 + configs.ORMDB().Where("type = ?", "").Delete(&Model{}) } // 从数据库加载 @@ -66,17 +71,16 @@ func (model *Model) Inference(image_list []Image, callback func(Image)) { if model.ServerID == "" { log.Println("模型未部署到推理機, 开始部署模型") - // 寻找一台就绪的推理机, 且已部署模型目标模型 - if err := configs.ORMDB().Where("type = ?", "推理").Where("status = ?", "就绪").Where("models LIKE ?", "%"+strconv.Itoa(model.ID)+"%").First(&server).Error; err != nil { - // 寻找一台就绪的推理机, 且模型位置仍有空余 - if err := configs.ORMDB().Where("type = ?", "推理").Where("status = ?", "就绪").Where("length(models) < ?", 5).First(&server).Error; err != nil { - log.Println("创建一台新的推理机: 当前禁止创建新服务器") - return - } - // 上传目标模型到推理机 - log.Println("上传模型到推理机: 当前禁止上传模型") + // 寻找一台就绪的且模型位置仍有空余的推理机 + if err := configs.ORMDB().Where("type = ?", "推理").Where("status = ?", "就绪").Where("length(models) < ?", 5).First(&server).Error; err != nil { + log.Println("创建一台新的推理机: 当前禁止创建新服务器") + return } + // 打印为格式化的json + data, _ := json.MarshalIndent(server, "", " ") + fmt.Println(string(data)) + //var form = struct { // Components []struct { // ID int `json:"id"` @@ -106,7 +110,7 @@ func (model *Model) Inference(image_list []Image, callback func(Image)) { //} // 记录到推理机 - server.Models = append(server.Models, strconv.Itoa(model.ID)) + server.Models = append(server.Models, model.ID) configs.ORMDB().Save(&server) // 记录到模型 diff --git a/models/server.go b/models/server.go index 854cb41..8ff309b 100644 --- a/models/server.go +++ b/models/server.go @@ -18,7 +18,7 @@ import ( cvm "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/cvm/v20170312" ) -type ModelList []string +type ModelList []int func (list *ModelList) Scan(value interface{}) error { return json.Unmarshal(value.([]byte), list) @@ -75,9 +75,62 @@ func init() { // 检查默认服务器是否存在, 不存在则添加 func InitDefaultServer() (err error) { - if err = configs.ORMDB().Where("id = ?", "default").First(&Server{}).Error; err != nil { - server := Server{ID: "default", IP: "106.15.192.42", Port: 7860, Type: "推理", Status: "就绪"} - err = configs.ORMDB().Create(&server).Error + var server Server + if err = configs.ORMDB().Where("id = ?", "default").First(&server).Error; err != nil { + server = Server{ID: "default", IP: "106.15.192.42", Port: 7860, Type: "推理", Status: "就绪"} + if err = configs.ORMDB().Create(&server).Error; err != nil { + return fmt.Errorf("创建默认服务器失败: %v", err) + } + } + // 初始化服务器中的模型列表 + if err = server.InitModels(); err != nil { + return fmt.Errorf("初始化服务器中的模型列表失败: %v", err) + } + return +} + +// 初始化服务器中的模型列表 +func (server *Server) InitModels() (err error) { + // 获取服务器中的模型列表 + resp, err := http.Get(fmt.Sprintf("http://%s:%d/sdapi/v1/sd-models", server.IP, server.Port)) + if err != nil { + return fmt.Errorf("获取服务器中的模型列表失败: %v", err) + } + defer resp.Body.Close() + + // 解码JSON (数组) + var data []map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { + return err + } + + // 从数据库检查此模型hash是否存在 + for _, item := range data { + var model Model + if err := configs.ORMDB().Where("hash = ?", item["sha256"].(string)).First(&model).Error; err != nil { + // 不存在则添加 + model = Model{ + Name: item["model_name"].(string), + Hash: item["sha256"].(string), + ModelCheckpoint: item["title"].(string), + ModelPath: item["filename"].(string), // TODO: 下载到本地 + ServerID: server.ID, + Type: "ckp", + } + // TODO: 下载到本地 + // 添加到数据库 + if err := configs.ORMDB().Create(&model).Error; err != nil { + return fmt.Errorf("添加模型到数据库失败: %v", err) + } + // 添加到模型列表 + server.Models = append(server.Models, model.ID) + + } + } + + // 更新数据库 + if err := configs.ORMDB().Save(&server).Error; err != nil { + return fmt.Errorf("更新数据库失败: %v", err) } return } diff --git a/test.sh b/test.sh index f17c917..d099155 100755 --- a/test.sh +++ b/test.sh @@ -47,73 +47,83 @@ message "$response" "登錄" session_id=$(echo "$response" | head -n -1 | grep -o '"id": "[^"]*' | cut -d '"' -f 4) #echo "session_id: $session_id" +# 获取模型列表 +response=$(curl -X GET -H "Content-Type: application/json" -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models) +message "$response" "模型列表" true -# 上傳圖片 (POST /api/images) -response=$(curl -X POST -H "Content-Type: multipart/form-data" -F "file=@./data/test.jpeg" -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/images) -message "$response" "上傳圖片" true - -# 臨時退出 -exit_service "測試結束, 全部通過, 用費時間: $(($(date +%s) - $start_time)) 秒" - -# 創建數據集, 應當在cookie中攜帶session_id (POST /api/datasets) -response=$(curl -X POST -H "Content-Type: application/json" -d '{"name":"test","description":"test"}' -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/datasets) -message "$response" "創建數據集" - - -# 取數據集id的值, 值爲 int -dataset_id=$(echo "${response%???}" | grep -o '"id": [0-9]*' | awk '{print $2}') -#echo "dataset_id: $dataset_id" - - -# 獲取數據集列表 (GET /api/datasets) -response=$(curl -X GET -H "Content-Type: application/json" -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/datasets) -message "$response" "數據集列表" - - -# 修改數據集, images 中增加 url (PATCH /api/datasets/:id) -response=$(curl -X PATCH -H "Content-Type: application/json" -d '{"images":["https://img.gameui.net/article-7258-1677745322000@1x456.webp","https://img.gameui.net/article-6477-1682109454000@1x456.webp"]}' -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/datasets/$dataset_id) -message "$response" "修改數據集" - - -# 添加服務器 (POST /api/servers) -response=$(curl -X POST -H "Content-Type: application/json" -d '{"name":"GPU-T4","type":"訓練","ip":"106.15.192.42","port":7860}' -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/servers) -message "$response" "添加服務器" - - -# 服務器列表 -response=$(curl -X GET -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/servers) -message "$response" "服務器列表" - - -# 創建模型訓練任務 (POST /api/models) -response=$(curl -X POST -H "Content-Type: application/json" -d '{"name":"test","type":"dreambooth","trigger_words":"miao~","base_model":"sd1.5","epochs":20,"description":"test","dataset_id":'$dataset_id'}' -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models) -message "$response" "創建模型訓練任務" true - - -# 取模型id的值, 值爲 int +# 使用第一个模型的id, 执行推理生成图像 model_id=$(echo "${response%???}" | grep -o '"id": [0-9]*' | awk '{print $2}') -#echo "model_id: $model_id" +echo "model_id: $model_id" + +response=$(curl -X POST -H "Content-Type: application/json" -d '{"model_id":'$model_id',"text":"miao~"}' -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/images) +message "$response" "推理生成图像" true + +## 上傳圖片 (POST /api/images) +#response=$(curl -X POST -H "Content-Type: multipart/form-data" -F "file=@./data/test.jpeg" -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/images) +#message "$response" "上傳圖片" +# +## 臨時退出 +##exit_service "測試結束, 全部通過, 用費時間: $(($(date +%s) - $start_time)) 秒" +# +## 創建數據集, 應當在cookie中攜帶session_id (POST /api/datasets) +#response=$(curl -X POST -H "Content-Type: application/json" -d '{"name":"test","description":"test"}' -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/datasets) +#message "$response" "創建數據集" +# +# +## 取數據集id的值, 值爲 int +#dataset_id=$(echo "${response%???}" | grep -o '"id": [0-9]*' | awk '{print $2}') +##echo "dataset_id: $dataset_id" +# +# +## 獲取數據集列表 (GET /api/datasets) +#response=$(curl -X GET -H "Content-Type: application/json" -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/datasets) +#message "$response" "數據集列表" +# +# +## 修改數據集, images 中增加 url (PATCH /api/datasets/:id) +#response=$(curl -X PATCH -H "Content-Type: application/json" -d '{"images":["https://img.gameui.net/article-7258-1677745322000@1x456.webp","https://img.gameui.net/article-6477-1682109454000@1x456.webp"]}' -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/datasets/$dataset_id) +#message "$response" "修改數據集" -# 循環獲取模型訓練進度, 直到訓練完成 -while true; do - # 獲取模型訓練進度 (GET /api/models/:id) - response=$(curl -X GET -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models/$model_id) - progress=$(echo "${response%???}" | grep -o '"progress": [0-9]*' | awk '{print $2}') - status=$(echo "${response%???}" | grep -o '"status": "[^"]*' | cut -d '"' -f 4) - message "$response" "獲取模型訓練進度 $progress% $status" - # 如果進度爲 100, 訓練完成, 跳出循環 - [[ $progress -eq 100 ]] && { echo "訓練完成"; break; } - # 測試訓練時間不超過10秒, 超過則退出 - [[ $(($(date +%s) - $start_time)) -gt 10 ]] && exit_service "訓練時間超過20秒" - # 休眠 3 秒 - sleep 3 -done - - -## 模型列表 (GET /api/models) -#response=$(curl -X GET -H "Content-Type: application/json" -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models) -#message "$response" "模型列表" +## 添加服務器 (POST /api/servers) +#response=$(curl -X POST -H "Content-Type: application/json" -d '{"name":"GPU-T4","type":"訓練","ip":"106.15.192.42","port":7860}' -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/servers) +#message "$response" "添加服務器" +# +# +## 服務器列表 +#response=$(curl -X GET -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/servers) +#message "$response" "服務器列表" +# +# +## 創建模型訓練任務 (POST /api/models) +#response=$(curl -X POST -H "Content-Type: application/json" -d '{"name":"test","type":"dreambooth","trigger_words":"miao~","base_model":"sd1.5","epochs":20,"description":"test","dataset_id":'$dataset_id'}' -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models) +#message "$response" "創建模型訓練任務" true +# +# +## 取模型id的值, 值爲 int +#model_id=$(echo "${response%???}" | grep -o '"id": [0-9]*' | awk '{print $2}') +##echo "model_id: $model_id" +# +# +## 循環獲取模型訓練進度, 直到訓練完成 +#while true; do +# # 獲取模型訓練進度 (GET /api/models/:id) +# response=$(curl -X GET -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models/$model_id) +# progress=$(echo "${response%???}" | grep -o '"progress": [0-9]*' | awk '{print $2}') +# status=$(echo "${response%???}" | grep -o '"status": "[^"]*' | cut -d '"' -f 4) +# message "$response" "獲取模型訓練進度 $progress% $status" +# # 如果進度爲 100, 訓練完成, 跳出循環 +# [[ $progress -eq 100 ]] && { echo "訓練完成"; break; } +# # 測試訓練時間不超過10秒, 超過則退出 +# [[ $(($(date +%s) - $start_time)) -gt 10 ]] && exit_service "訓練時間超過20秒" +# # 休眠 3 秒 +# sleep 3 +#done +# +# +### 模型列表 (GET /api/models) +##response=$(curl -X GET -H "Content-Type: application/json" -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models) +##message "$response" "模型列表" exit_service "測試結束, 全部通過, 用費時間: $(($(date +%s) - $start_time)) 秒"