Files
ai/routers/models.go
2023-05-12 14:16:27 +08:00

131 lines
3.2 KiB
Go

package routers
import (
"encoding/json"
"io/ioutil"
"log"
"main/models"
"main/utils"
"net/http"
"strconv"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
)
var manager = models.NewWebSocketManager()
func ModelsGet(w http.ResponseWriter, r *http.Request) {
var listview models.ListView
listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1)
listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10)
listview.List = models.QueryModels(listview.Page, listview.PageSize)
listview.Total = models.CountModels()
listview.Next = listview.Page*listview.PageSize < listview.Total
listview.WriteJSON(w)
}
func ModelsPost(w http.ResponseWriter, r *http.Request) {
var model models.Model
body, err := ioutil.ReadAll(r.Body)
if err != nil {
log.Println(err)
return
}
defer r.Body.Close()
if err = json.Unmarshal(body, &model); err != nil {
log.Println(err)
return
}
model.Create()
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(model))
}
func ModelItemGet(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Upgrade") == "websocket" {
vars := mux.Vars(r)
id, _ := strconv.Atoi(vars["id"])
model := models.QueryModel(id)
if model.ID == 0 {
w.WriteHeader(http.StatusNotFound)
return
}
upgrader := websocket.Upgrader{}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Println(err)
return
}
defer conn.Close()
wsid := manager.AddConnection(conn)
defer manager.RemoveConnection(wsid)
for {
_, msg, err := conn.ReadMessage()
if err != nil {
log.Println(err)
return
}
log.Println(string(msg))
if string(msg) == "close" {
break
}
}
return
}
model := models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
model.Get()
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(model))
}
func ModelItemPatch(w http.ResponseWriter, r *http.Request) {
// 取出原始数据
model := models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
if err := model.Get(); err != nil {
w.WriteHeader(http.StatusNotFound)
return
}
// 取出更新数据
var model_new models.Model
body, err := ioutil.ReadAll(r.Body)
if err != nil {
log.Println(err)
return
}
defer r.Body.Close()
if err = json.Unmarshal(body, &model_new); err != nil {
log.Println(err)
return
}
// 字段不爲空且不等於原始數據時更新
if model_new.Name != "" && model_new.Name != model.Name {
model.Name = model_new.Name
}
if model_new.Type != "" && model_new.Type != model.Type {
model.Type = model_new.Type
}
if model_new.Status != "" && model_new.Status != model.Status {
model.Status = model_new.Status
// 如果狀態被改變爲 ready, 將模型發送到訓練隊列
if model.Status == "ready" {
model.SendToTrain()
}
}
if model_new.Image != "" && model_new.Image != model.Image {
model.Image = model_new.Image
}
model.Update()
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(model))
}
func ModelItemDelete(w http.ResponseWriter, r *http.Request) {
model := models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
model.Delete()
w.WriteHeader(http.StatusNoContent)
}