Files
ai/routers/models.go
2023-05-28 09:50:01 +08:00

213 lines
5.2 KiB
Go

package routers
import (
"encoding/json"
"fmt"
"io/ioutil"
"log"
"main/configs"
"main/models"
"main/utils"
"net/http"
"strconv"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
)
// manager is the package-level websocket manager. ModelItemGet registers each
// upgraded connection with it (AddConnection) and unregisters the connection
// when the handler returns (RemoveConnection).
var manager = models.NewWebSocketManager()
// ModelsGet writes one page of models as a JSON ListView.
//
// Query parameters:
//   - page: page number, defaults to 1; values below 1 are treated as 1.
//   - pageSize: rows per page, defaults to 10.
//
// The response carries the requested rows, the total row count and a Next
// flag telling the client whether another page exists. If either the page
// query or the count query fails, the error is logged and the handler replies
// with 500 Internal Server Error instead of a partially filled list.
func ModelsGet(w http.ResponseWriter, r *http.Request) {
	query := r.URL.Query()

	var listview models.ListView
	listview.Page = utils.ParamInt(query.Get("page"), 1)
	listview.PageSize = utils.ParamInt(query.Get("pageSize"), 10)
	// A page below 1 would yield a negative offset and a wrong Next flag.
	if listview.Page < 1 {
		listview.Page = 1
	}

	db := configs.ORMDB()

	// Fetch the rows of the requested page.
	var rows []models.Model
	offset := (listview.Page - 1) * listview.PageSize
	if err := db.Offset(offset).Limit(listview.PageSize).Find(&rows).Error; err != nil {
		log.Println(err)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	for i := range rows {
		listview.List = append(listview.List, rows[i])
	}

	// Count all rows so the client can tell whether more pages follow.
	if err := db.Model(&models.Model{}).Count(&listview.Total).Error; err != nil {
		log.Println(err)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	listview.Next = listview.Page*listview.PageSize < int(listview.Total)

	listview.WriteJSON(w)
}
// ModelsPost creates a model from the JSON request body and immediately
// submits it for training.
//
// The work runs inside the models.AccountRead callback; the account passed to
// the callback becomes the owner of the new model (UserID), and the model is
// stored with status "initial".
//
// Responses:
//   - 400 Bad Request when the body cannot be read, is not valid JSON, or a
//     required field (name, type, trigger words, base model, epochs) is
//     missing; the body names the problem.
//   - 500 Internal Server Error when the insert fails.
//   - Otherwise the created model as JSON.
func ModelsPost(w http.ResponseWriter, r *http.Request) {
	models.AccountRead(w, r, func(account *models.Account) {
		fmt.Println(account)
		// TODO: check permission (whether this account may create models)

		// reject replies with 400 and a plain-text explanation.
		reject := func(message string) {
			w.WriteHeader(http.StatusBadRequest)
			w.Write([]byte(message))
		}

		// Decode the model to create.
		defer r.Body.Close()
		body, err := ioutil.ReadAll(r.Body)
		if err != nil {
			log.Println(err)
			reject(err.Error())
			return
		}
		var model models.Model
		if err = json.Unmarshal(body, &model); err != nil {
			log.Println(err)
			reject(err.Error())
			return
		}

		// Validate required fields; the first failing check wins.
		switch {
		case model.Name == "":
			reject("模型名稱不能為空")
			return
		case model.Type == "":
			reject("模型類型不能為空(recommend|lora|ckp|hyper|ti)")
			return
		case model.TriggerWords == "":
			reject("觸發詞不能為空")
			return
		case model.BaseModel == "":
			reject("基礎模型不能為空(SD1.5|SD2)")
			return
		case model.Epochs <= 0:
			reject("訓練輪數不能小於0")
			return
		}

		// Normalise a missing tag list to an empty one before storing.
		if model.Tags == nil {
			model.Tags = []string{}
		}
		model.UserID = account.ID
		model.Status = "initial"

		if err := configs.ORMDB().Create(&model).Error; err != nil {
			log.Println(err)
			w.WriteHeader(http.StatusInternalServerError)
			return
		}

		// Submit the training job right away, in the background.
		go model.Train()

		// Return the created model.
		w.Header().Set("Content-Type", "application/json; charset=utf-8")
		w.Write(utils.ToJSON(model))
	})
}
// ModelItemGet serves a single model identified by the "id" route variable.
//
// Plain HTTP request: replies with the model as JSON, or 404 Not Found plus
// the lookup error text when the model cannot be loaded.
//
// WebSocket upgrade request: verifies the model exists (404 otherwise),
// upgrades the connection, registers it with the package-level manager and
// then reads and logs incoming messages until the client sends the text
// "close" or a read fails. The connection is unregistered and closed when the
// handler returns.
func ModelItemGet(w http.ResponseWriter, r *http.Request) {
	rawID := mux.Vars(r)["id"]

	// IsWebSocketUpgrade matches the Upgrade/Connection header tokens
	// case-insensitively, unlike a direct comparison with "websocket".
	if websocket.IsWebSocketUpgrade(r) {
		id, err := strconv.Atoi(rawID)
		if err != nil {
			// An id that is not a number cannot name an existing model.
			w.WriteHeader(http.StatusNotFound)
			return
		}
		model := models.Model{ID: id}
		if err := configs.ORMDB().Take(&model, id).Error; err != nil {
			w.WriteHeader(http.StatusNotFound)
			return
		}

		upgrader := websocket.Upgrader{}
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			// Per the gorilla/websocket documentation, Upgrade has already
			// replied to the client with an HTTP error at this point.
			log.Println(err)
			return
		}
		defer conn.Close()

		wsid := manager.AddConnection(conn)
		defer manager.RemoveConnection(wsid)

		// Keep the connection open, logging whatever the client sends.
		for {
			_, msg, err := conn.ReadMessage()
			if err != nil {
				log.Println(err)
				return
			}
			log.Println(string(msg))
			if string(msg) == "close" {
				return
			}
		}
	}

	id := utils.ParamInt(rawID, 0)
	model := models.Model{ID: id}
	if err := configs.ORMDB().Take(&model, id).Error; err != nil {
		w.WriteHeader(http.StatusNotFound)
		w.Write([]byte(err.Error()))
		return
	}
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	w.Write(utils.ToJSON(model))
}
// 更新模型
func ModelItemPatch(w http.ResponseWriter, r *http.Request) {
var model = models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
if err := configs.ORMDB().Take(&model, utils.ParamInt(mux.Vars(r)["id"], 0)).Error; err != nil {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte(err.Error()))
return
}
// 取出更新数据
var model_new models.Model
body, err := ioutil.ReadAll(r.Body)
if err != nil {
log.Println(err)
return
}
defer r.Body.Close()
if err = json.Unmarshal(body, &model_new); err != nil {
log.Println(err)
return
}
// 字段不爲空且不等於原始數據時更新
if model_new.Name != "" && model_new.Name != model.Name {
model.Name = model_new.Name
}
if model_new.Type != "" && model_new.Type != model.Type {
model.Type = model_new.Type
}
if model_new.Status != "" && model_new.Status != model.Status {
model.Status = model_new.Status
// 如果狀態被改變爲 ready, 將模型發送到訓練隊列
if model.Status == "ready" {
model.Status = "training"
go model.Train()
}
}
if model_new.Image != "" && model_new.Image != model.Image {
model.Image = model_new.Image
}
// 執行更新
if err := configs.ORMDB().Save(&model).Error; err != nil {
log.Println(err)
return
}
// 返回更新後的數據
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(model))
}
// ModelItemDelete deletes the model identified by the "id" route variable and
// replies with the deleted model as JSON.
//
// The work runs inside the models.AccountRead callback, and only the account
// recorded as the model's owner (UserID) may delete it.
// NOTE(review): this relies on AccountRead handling requests without a valid
// account itself, as ModelsPost already assumes — confirm.
//
// Responses:
//   - 404 Not Found when the model cannot be loaded.
//   - 403 Forbidden when the model belongs to a different account.
//   - 500 Internal Server Error when the delete fails.
func ModelItemDelete(w http.ResponseWriter, r *http.Request) {
	models.AccountRead(w, r, func(account *models.Account) {
		id := utils.ParamInt(mux.Vars(r)["id"], 0)
		model := models.Model{ID: id}
		// The query result itself is never nil; the failure is carried in its
		// Error field, so that is what must be checked.
		if err := configs.ORMDB().Take(&model, id).Error; err != nil {
			w.WriteHeader(http.StatusNotFound)
			return
		}
		if model.UserID != account.ID {
			w.WriteHeader(http.StatusForbidden)
			return
		}
		if err := configs.ORMDB().Delete(&model).Error; err != nil {
			log.Println(err)
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "application/json; charset=utf-8")
		w.Write(utils.ToJSON(model))
	})
}