Files
ai/routers/models.go
2023-05-16 01:51:41 +08:00

166 lines
4.2 KiB
Go

package routers
import (
"encoding/json"
"fmt"
"io/ioutil"
"log"
"main/configs"
"main/models"
"main/utils"
"net/http"
"strconv"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
)
// manager is the package-wide registry of open WebSocket connections.
// In this file, ModelItemGet adds each upgraded connection to it and removes
// the connection again when the socket's read loop ends.
var (
	manager = models.NewWebSocketManager()
)
// ModelsGet handles requests for a paginated list of models.
//
// Query parameters (parsed with utils.ParamInt):
//   - page: 1-based page number, default 1; values below 1 are treated as 1.
//   - pageSize: items per page, default 10; values below 1 are treated as 10.
//
// On success the page is written with models.ListView.WriteJSON, including
// the total row count and whether another page follows. Any database error
// is logged and answered with 500 Internal Server Error.
func ModelsGet(w http.ResponseWriter, r *http.Request) {
	query := r.URL.Query()

	var listview models.ListView
	listview.Page = utils.ParamInt(query.Get("page"), 1)
	listview.PageSize = utils.ParamInt(query.Get("pageSize"), 10)
	// Non-positive values would yield a negative OFFSET/LIMIT; with GORM a
	// negative limit removes the LIMIT clause and returns the whole table.
	if listview.Page < 1 {
		listview.Page = 1
	}
	if listview.PageSize < 1 {
		listview.PageSize = 10
	}

	db := configs.ORMDB()
	var modelList []models.Model
	offset := (listview.Page - 1) * listview.PageSize
	if err := db.Offset(offset).Limit(listview.PageSize).Find(&modelList).Error; err != nil {
		log.Println(err)
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}
	for _, model := range modelList {
		listview.List = append(listview.List, model)
	}

	if err := db.Model(&models.Model{}).Count(&listview.Total).Error; err != nil {
		log.Println(err)
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}
	listview.Next = listview.Page*listview.PageSize < int(listview.Total)
	listview.WriteJSON(w)
}
// ModelsPost handles requests that create a new model.
//
// The handler body runs inside models.AccountRead, which resolves the
// caller's account (presumably rejecting unauthenticated requests itself —
// its behaviour is defined elsewhere). The request body is decoded as JSON
// into a models.Model, inserted through the ORM, and the stored record is
// echoed back as JSON.
//
// Error responses:
//   - 400 Bad Request when the body cannot be read or is not valid JSON.
//   - 500 Internal Server Error when the insert fails.
func ModelsPost(w http.ResponseWriter, r *http.Request) {
	models.AccountRead(w, r, func(account *models.Account) {
		// TODO: check whether this account is allowed to create models.

		body, err := ioutil.ReadAll(r.Body)
		if err != nil {
			log.Println(err)
			http.Error(w, fmt.Sprintf("reading request body: %v", err), http.StatusBadRequest)
			return
		}

		var model models.Model
		if err := json.Unmarshal(body, &model); err != nil {
			log.Println(err)
			http.Error(w, fmt.Sprintf("decoding model: %v", err), http.StatusBadRequest)
			return
		}

		if err := configs.ORMDB().Create(&model).Error; err != nil {
			// Log the details server-side; don't leak database errors to clients.
			log.Println(err)
			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
			return
		}

		w.Header().Set("Content-Type", "application/json; charset=utf-8")
		w.Write(utils.ToJSON(model))
	})
}
// ModelItemGet handles requests for the single model identified by the "id"
// route variable.
//
// Plain HTTP requests receive the model as JSON. WebSocket upgrade requests
// are upgraded and served by serveModelSocket after the model has been
// confirmed to exist.
//
// Error responses:
//   - 400 Bad Request when "id" is not an integer.
//   - 404 Not Found when the lookup fails (any database error is reported
//     this way).
func ModelItemGet(w http.ResponseWriter, r *http.Request) {
	id, err := strconv.Atoi(mux.Vars(r)["id"])
	if err != nil {
		http.Error(w, "invalid model id", http.StatusBadRequest)
		return
	}

	model := models.Model{ID: id}
	// Take returns the query handle itself, which is never nil; the lookup
	// result must be read from its Error field.
	if err := configs.ORMDB().Take(&model, id).Error; err != nil {
		w.WriteHeader(http.StatusNotFound)
		return
	}

	// IsWebSocketUpgrade checks the Connection/Upgrade tokens
	// case-insensitively, unlike a raw string comparison on the header.
	if websocket.IsWebSocketUpgrade(r) {
		serveModelSocket(w, r)
		return
	}

	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	w.Write(utils.ToJSON(model))
}

// serveModelSocket upgrades the request to a WebSocket, registers the
// connection with manager for its lifetime, and logs incoming messages until
// the client sends "close" or a read fails.
func serveModelSocket(w http.ResponseWriter, r *http.Request) {
	var upgrader websocket.Upgrader
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		// Upgrade has already written an HTTP error response to the client.
		log.Println(err)
		return
	}
	defer conn.Close()

	wsid := manager.AddConnection(conn)
	defer manager.RemoveConnection(wsid)

	for {
		_, msg, err := conn.ReadMessage()
		if err != nil {
			log.Println(err)
			return
		}
		log.Println(string(msg))
		if string(msg) == "close" {
			return
		}
	}
}
func ModelItemPatch(w http.ResponseWriter, r *http.Request) {
var model = models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
if err := configs.ORMDB().Take(&model, utils.ParamInt(mux.Vars(r)["id"], 0)).Error; err != nil {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte(err.Error()))
return
}
// 取出更新数据
var model_new models.Model
body, err := ioutil.ReadAll(r.Body)
if err != nil {
log.Println(err)
return
}
defer r.Body.Close()
if err = json.Unmarshal(body, &model_new); err != nil {
log.Println(err)
return
}
// 字段不爲空且不等於原始數據時更新
if model_new.Name != "" && model_new.Name != model.Name {
model.Name = model_new.Name
}
if model_new.Type != "" && model_new.Type != model.Type {
model.Type = model_new.Type
}
if model_new.Status != "" && model_new.Status != model.Status {
model.Status = model_new.Status
// 如果狀態被改變爲 ready, 將模型發送到訓練隊列
if model.Status == "ready" {
model.Status = "training"
go model.Train()
}
}
if model_new.Image != "" && model_new.Image != model.Image {
model.Image = model_new.Image
}
// 執行更新
if err := configs.ORMDB().Save(&model).Error; err != nil {
log.Println(err)
return
}
// 返回更新後的數據
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(model))
}
func ModelItemDelete(w http.ResponseWriter, r *http.Request) {
var model = models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
if err := configs.ORMDB().Take(&model, utils.ParamInt(mux.Vars(r)["id"], 0)); err != nil {
w.WriteHeader(http.StatusNotFound)
return
}
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(model))
}