package routers

import (
	"encoding/json"
	"log"
	"net/http"

	"github.com/gorilla/mux"
	"github.com/gorilla/websocket"

	"main/configs"
	"main/models"
	"main/utils"
)

// manager is the registry for WebSocket connections. ModelItemGet adds each
// upgraded connection to it and removes the connection when its read loop ends.
var manager = models.NewWebSocketManager()

// modelSocketUpgrader performs the WebSocket handshake for ModelItemGet.
// It is shared across requests instead of being rebuilt per call. Because
// CheckOrigin is nil, gorilla's default origin check applies (a present Origin
// header whose host differs from the request Host is rejected).
var modelSocketUpgrader = websocket.Upgrader{}

// ModelsGet writes one page of models as a models.ListView.
//
// Query parameters:
//   - page: 1-based page number (default 1; values below 1 are treated as 1)
//   - pageSize: rows per page (default 10; values below 1 are treated as 10)
//
// Responds 500 if either database query fails.
func ModelsGet(w http.ResponseWriter, r *http.Request) {
	query := r.URL.Query()

	var listview models.ListView
	listview.Page = utils.ParamInt(query.Get("page"), 1)
	listview.PageSize = utils.ParamInt(query.Get("pageSize"), 10)
	// Clamp so a page of 0 cannot produce a negative OFFSET, and a page size
	// of 0 cannot return an empty page whose Next flag is always true.
	if listview.Page < 1 {
		listview.Page = 1
	}
	if listview.PageSize < 1 {
		listview.PageSize = 10
	}

	db := configs.ORMDB()

	var modelList []models.Model
	offset := (listview.Page - 1) * listview.PageSize
	if err := db.Offset(offset).Limit(listview.PageSize).Find(&modelList).Error; err != nil {
		respondModelDBError(w, err)
		return
	}
	for _, model := range modelList {
		listview.List = append(listview.List, model)
	}

	if err := db.Model(&models.Model{}).Count(&listview.Total).Error; err != nil {
		respondModelDBError(w, err)
		return
	}
	listview.Next = listview.Page*listview.PageSize < int(listview.Total)
	listview.WriteJSON(w)
}

// ModelsPost creates a model (a new training job) owned by the account that
// models.AccountRead resolves for the request. It echoes the stored record
// back as JSON.
//
// The request body is a JSON-encoded models.Model. Type, TriggerWords,
// BaseModel and a positive Epochs are required, and a blank Name is replaced
// with a random 8-character one. ID, UserID and Status are always set on the
// server. Responds 400 on a malformed or invalid body and 500 if the insert
// fails.
func ModelsPost(w http.ResponseWriter, r *http.Request) {
	models.AccountRead(w, r, func(account *models.Account) {
		var model models.Model
		if err := json.NewDecoder(r.Body).Decode(&model); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		if problem := newModelProblem(&model); problem != "" {
			http.Error(w, problem, http.StatusBadRequest)
			return
		}

		if model.Name == "" {
			model.Name = utils.RandomString(8)
		}
		if model.Tags == nil {
			// Normalize a missing tag list to an empty one (presumably so it is
			// stored and serialized as [] rather than null).
			model.Tags = []string{}
		}
		// These fields are owned by the server; never trust client-supplied values.
		model.ID = 0
		model.UserID = account.ID
		model.Status = "initial"

		if err := configs.ORMDB().Create(&model).Error; err != nil {
			respondModelDBError(w, err)
			return
		}
		// Training is not started here. ModelItemPatch starts it when the
		// status is changed to "ready".
		writeModelJSON(w, model)
	})
}

// newModelProblem validates a model submitted to ModelsPost. It returns the
// user-facing reason the model is rejected, or "" if it is acceptable.
func newModelProblem(model *models.Model) string {
	switch {
	case model.Type == "":
		return "模型類型不能為空(recommend|lora|ckp|hyper|ti)"
	case model.TriggerWords == "":
		return "觸發詞不能為空"
	case model.BaseModel == "":
		return "基礎模型不能為空(SD1.5|SD2)"
	case model.Epochs <= 0:
		return "訓練輪數必須大於0"
	}
	return ""
}

// ModelItemGet serves the single model identified by the {id} route variable.
//
// A plain request receives the model as JSON. A WebSocket upgrade request is
// upgraded and registered with manager instead (see serveModelSocket). In both
// cases the handler responds 404 if the model cannot be loaded.
func ModelItemGet(w http.ResponseWriter, r *http.Request) {
	model, err := modelByPathID(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusNotFound)
		return
	}
	if websocket.IsWebSocketUpgrade(r) {
		serveModelSocket(w, r)
		return
	}
	writeModelJSON(w, model)
}

// serveModelSocket upgrades the request to a WebSocket and keeps it registered
// with manager. It stops when the peer sends the text "close", closes the
// connection, or a read fails. Every received message is logged.
// NOTE(review): incoming messages are only logged. Presumably manager pushes
// model updates to the client; confirm against models.WebSocketManager.
func serveModelSocket(w http.ResponseWriter, r *http.Request) {
	conn, err := modelSocketUpgrader.Upgrade(w, r, nil)
	if err != nil {
		// Upgrade has already replied to the client with an HTTP error.
		log.Println(err)
		return
	}
	defer conn.Close()

	wsid := manager.AddConnection(conn)
	defer manager.RemoveConnection(wsid)

	for {
		_, msg, err := conn.ReadMessage()
		if err != nil {
			// A normal close by the peer is expected; log everything else.
			if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
				log.Println(err)
			}
			return
		}
		text := string(msg)
		log.Println(text)
		if text == "close" {
			return
		}
	}
}

// ModelItemPatch applies a partial update to the model identified by {id}.
//
// The body is a JSON-encoded models.Model. Only non-empty Name, Type, Image
// and Status fields are applied. Changing Status to "ready" stores it as
// "training" and, once the update is saved, starts Train in the background.
// Responses:
//   - 404 if the model is missing
//   - 400 on a malformed body
//   - 500 if the save fails
//   - otherwise the updated model as JSON
//
// NOTE(review): unlike ModelsPost, this handler performs no account or
// ownership check. Confirm authorization is enforced elsewhere.
func ModelItemPatch(w http.ResponseWriter, r *http.Request) {
	model, err := modelByPathID(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusNotFound)
		return
	}

	var patch models.Model
	if err := json.NewDecoder(r.Body).Decode(&patch); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	// Empty fields in the patch mean "leave unchanged".
	if patch.Name != "" {
		model.Name = patch.Name
	}
	if patch.Type != "" {
		model.Type = patch.Type
	}
	if patch.Image != "" {
		model.Image = patch.Image
	}
	startTraining := false
	if patch.Status != "" && patch.Status != model.Status {
		model.Status = patch.Status
		if model.Status == "ready" {
			// "ready" is a request to train, so record it as already training.
			model.Status = "training"
			startTraining = true
		}
	}

	if err := configs.ORMDB().Save(&model).Error; err != nil {
		respondModelDBError(w, err)
		return
	}

	if startTraining {
		// Start only after the row is saved, so a failed save never launches
		// training and Train's own DB writes cannot be overwritten by Save.
		// Train gets its own copy of the struct so it does not race with the
		// response encoding below.
		trainee := model
		go trainee.Train()
	}
	writeModelJSON(w, model)
}

// ModelItemDelete deletes the model identified by {id} and returns the deleted
// record as JSON. The delete is soft if models.Model has a gorm DeletedAt
// field; verify. Responds 404 if the model cannot be loaded and 500 if the
// delete fails.
//
// NOTE(review): there is no account or ownership check here. Confirm
// authorization is enforced elsewhere.
func ModelItemDelete(w http.ResponseWriter, r *http.Request) {
	model, err := modelByPathID(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusNotFound)
		return
	}
	if err := configs.ORMDB().Delete(&model).Error; err != nil {
		respondModelDBError(w, err)
		return
	}
	writeModelJSON(w, model)
}

// modelByPathID loads the model whose primary key is the {id} route variable.
// A missing or non-numeric id presumably falls back to 0 via utils.ParamInt.
func modelByPathID(r *http.Request) (models.Model, error) {
	var model models.Model
	id := utils.ParamInt(mux.Vars(r)["id"], 0)
	if err := configs.ORMDB().Take(&model, id).Error; err != nil {
		return models.Model{}, err
	}
	return model, nil
}

// writeModelJSON writes model, serialized by utils.ToJSON, as the JSON
// response body.
func writeModelJSON(w http.ResponseWriter, model models.Model) {
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	w.Write(utils.ToJSON(model))
}

// respondModelDBError logs a database failure and answers 500 without exposing
// the underlying error text to the client.
func respondModelDBError(w http.ResponseWriter, err error) {
	log.Println(err)
	http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}