246 lines
6.0 KiB
Go
246 lines
6.0 KiB
Go
package routers
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"log"
|
|
"main/configs"
|
|
"main/models"
|
|
"main/utils"
|
|
"net/http"
|
|
"strconv"
|
|
|
|
"github.com/gorilla/mux"
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// manager tracks the WebSocket connections accepted by ModelItemGet, which
// registers each connection after the upgrade and removes it again when the
// handler returns.
var manager = models.NewWebSocketManager()
|
|
|
|
// 獲取模型列表
|
|
func ModelsGet(w http.ResponseWriter, r *http.Request) {
|
|
var listview models.ListView
|
|
listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1)
|
|
listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10)
|
|
|
|
var model_list []models.Model
|
|
db := configs.ORMDB()
|
|
|
|
// 按照 user_id 篩選
|
|
if user_id := utils.ParamInt(r.URL.Query().Get("user_id"), 0); user_id > 0 {
|
|
db = db.Where("user_id = ?", user_id)
|
|
}
|
|
|
|
// 按照 star 篩選
|
|
if star := utils.ParamInt(r.URL.Query().Get("star"), 0); star > 0 {
|
|
db = db.Where("stars LIKE ?", "%"+strconv.Itoa(star)+"%")
|
|
}
|
|
|
|
// 按照 name 模糊搜索
|
|
if name := r.URL.Query().Get("name"); name != "" {
|
|
db = db.Where("name LIKE ?", "%"+name+"%")
|
|
}
|
|
|
|
// 按照 type 篩選
|
|
if model_type := r.URL.Query().Get("type"); model_type != "" {
|
|
db = db.Where("type = ?", model_type)
|
|
}
|
|
|
|
// 按照 status 篩選
|
|
if status := r.URL.Query().Get("status"); status != "" {
|
|
db = db.Where("status = ?", status)
|
|
}
|
|
|
|
// 按照 tag 篩選
|
|
if tag := r.URL.Query().Get("tag"); tag != "" {
|
|
db = db.Where("tags LIKE ?", "%"+tag+"%")
|
|
}
|
|
|
|
db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Find(&model_list)
|
|
for _, model := range model_list {
|
|
listview.List = append(listview.List, model)
|
|
}
|
|
db.Model(&models.Model{}).Count(&listview.Total)
|
|
listview.Next = listview.Page*listview.PageSize < int(listview.Total)
|
|
listview.WriteJSON(w)
|
|
}
|
|
|
|
// 創建模型(訓練新模型)
|
|
func ModelsPost(w http.ResponseWriter, r *http.Request) {
|
|
models.AccountRead(w, r, func(account *models.Account) {
|
|
fmt.Println(account)
|
|
|
|
// 創建模型
|
|
var model models.Model
|
|
body, err := ioutil.ReadAll(r.Body)
|
|
if err != nil {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
w.Write([]byte(err.Error()))
|
|
return
|
|
}
|
|
defer r.Body.Close()
|
|
|
|
if err = json.Unmarshal(body, &model); err != nil {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
w.Write([]byte(err.Error()))
|
|
return
|
|
}
|
|
|
|
if model.Name == "" {
|
|
model.Name = utils.RandomString(8)
|
|
}
|
|
|
|
if model.Type == "" {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
w.Write([]byte("模型類型不能為空(recommend|lora|ckp|hyper|ti)"))
|
|
return
|
|
}
|
|
|
|
if model.TriggerWords == "" {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
w.Write([]byte("觸發詞不能為空"))
|
|
return
|
|
}
|
|
|
|
if model.BaseModel == "" {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
w.Write([]byte("基礎模型不能為空(SD1.5|SD2)"))
|
|
return
|
|
}
|
|
|
|
if model.Epochs <= 0 {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
w.Write([]byte("訓練輪數不能小於0"))
|
|
return
|
|
}
|
|
|
|
if model.Tags == nil {
|
|
model.Tags = []string{}
|
|
}
|
|
|
|
model.UserID = account.ID
|
|
model.Status = "initial"
|
|
|
|
if err := configs.ORMDB().Create(&model).Error; err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
|
|
// 直接提交訓練任務
|
|
// go model.Train()
|
|
|
|
// 返回創建的模型
|
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
|
w.Write(utils.ToJSON(model))
|
|
})
|
|
}
|
|
|
|
// 獲取模型詳情
|
|
func ModelItemGet(w http.ResponseWriter, r *http.Request) {
|
|
if r.Header.Get("Upgrade") == "websocket" {
|
|
vars := mux.Vars(r)
|
|
id, _ := strconv.Atoi(vars["id"])
|
|
|
|
var model = models.Model{ID: id}
|
|
if err := configs.ORMDB().Take(&model, id).Error; err != nil {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
return
|
|
}
|
|
|
|
upgrader := websocket.Upgrader{}
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
wsid := manager.AddConnection(conn)
|
|
defer manager.RemoveConnection(wsid)
|
|
for {
|
|
_, msg, err := conn.ReadMessage()
|
|
if err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
log.Println(string(msg))
|
|
if string(msg) == "close" {
|
|
break
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
var model = models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
|
|
if err := configs.ORMDB().Take(&model, utils.ParamInt(mux.Vars(r)["id"], 0)).Error; err != nil {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
w.Write([]byte(err.Error()))
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
|
w.Write(utils.ToJSON(model))
|
|
}
|
|
|
|
// 更新模型
|
|
func ModelItemPatch(w http.ResponseWriter, r *http.Request) {
|
|
var model = models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
|
|
if err := configs.ORMDB().Take(&model, utils.ParamInt(mux.Vars(r)["id"], 0)).Error; err != nil {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
w.Write([]byte(err.Error()))
|
|
return
|
|
}
|
|
|
|
// 取出更新数据
|
|
var model_new models.Model
|
|
body, err := ioutil.ReadAll(r.Body)
|
|
if err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
defer r.Body.Close()
|
|
if err = json.Unmarshal(body, &model_new); err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
|
|
// 字段不爲空且不等於原始數據時更新
|
|
if model_new.Name != "" && model_new.Name != model.Name {
|
|
model.Name = model_new.Name
|
|
}
|
|
if model_new.Type != "" && model_new.Type != model.Type {
|
|
model.Type = model_new.Type
|
|
}
|
|
if model_new.Status != "" && model_new.Status != model.Status {
|
|
model.Status = model_new.Status
|
|
// 如果狀態被改變爲 ready, 將模型發送到訓練隊列
|
|
if model.Status == "ready" {
|
|
model.Status = "training"
|
|
go model.Train()
|
|
}
|
|
}
|
|
if model_new.Image != "" && model_new.Image != model.Image {
|
|
model.Image = model_new.Image
|
|
}
|
|
|
|
// 執行更新
|
|
if err := configs.ORMDB().Save(&model).Error; err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
|
|
// 返回更新後的數據
|
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
|
w.Write(utils.ToJSON(model))
|
|
}
|
|
|
|
// 刪除模型
|
|
func ModelItemDelete(w http.ResponseWriter, r *http.Request) {
|
|
var model = models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
|
|
if err := configs.ORMDB().Take(&model, utils.ParamInt(mux.Vars(r)["id"], 0)); err != nil {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
return
|
|
}
|
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
|
w.Write(utils.ToJSON(model))
|
|
}
|