Files
ai/routers/servers.go
2023-05-25 09:39:03 +08:00

113 lines
3.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package routers
import (
	"encoding/json"
	"fmt"
	"io/ioutil"
	"net/http"
	"time"

	"github.com/google/uuid"
	"github.com/gorilla/mux"

	"main/configs"
	"main/models"
	"main/utils"
)
// serversProbeClient is shared by the ServersGet probes. Unlike
// http.DefaultClient it has a timeout, so one unreachable server cannot
// stall the whole listing request.
var serversProbeClient = &http.Client{Timeout: 5 * time.Second}

// ServersGet lists servers one page at a time and annotates each one with
// live information fetched from the server itself.
//
// Query parameters (values below 1 fall back to the default):
//   - page:     1-based page number, default 1
//   - pageSize: servers per page, default 10
//
// For every server on the page:
//   - Status is "正常" when GET http://IP:Port/docs answers 200, "異常" otherwise.
//   - Models is replaced with the JSON array returned by
//     GET http://IP:Port/sdapi/v1/sd-models when that request answers 200
//     (an empty list if the body cannot be decoded); on any other outcome
//     Models keeps whatever value was loaded from the database.
//
// Total is filled from a COUNT over all servers so that Next correctly
// reports whether another page exists. Database errors produce a 500.
func ServersGet(w http.ResponseWriter, r *http.Request) {
	var listview models.ListView
	listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1)
	listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10)
	// User input must not produce a negative offset or limit.
	if listview.Page < 1 {
		listview.Page = 1
	}
	if listview.PageSize < 1 {
		listview.PageSize = 10
	}

	db := configs.ORMDB()
	// Without Total, the Next computation below would always be false.
	if err := db.Model(&models.Server{}).Count(&listview.Total).Error; err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(err.Error()))
		return
	}
	var serverList []models.Server
	offset := (listview.Page - 1) * listview.PageSize
	if err := db.Offset(offset).Limit(listview.PageSize).Find(&serverList).Error; err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(err.Error()))
		return
	}

	for _, server := range serverList {
		baseURL := fmt.Sprintf("http://%s:%d", server.IP, server.Port)
		server.Status = probeServerStatus(baseURL + "/docs")
		if list, ok := fetchServerModels(baseURL + "/sdapi/v1/sd-models"); ok {
			server.Models = list
		}
		listview.List = append(listview.List, server)
	}
	listview.Next = listview.Page*listview.PageSize < int(listview.Total)
	listview.WriteJSON(w)
}

// probeServerStatus returns "正常" when endpoint answers HTTP 200 and "異常"
// on any request error or other status code.
func probeServerStatus(endpoint string) string {
	resp, err := serversProbeClient.Get(endpoint)
	if err != nil {
		return "異常"
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return "異常"
	}
	return "正常"
}

// fetchServerModels GETs endpoint and decodes its body as a JSON array of
// objects. ok is false when the request fails or the status is not 200; a
// 200 response whose body cannot be read or decoded yields an empty,
// non-nil list with ok true.
func fetchServerModels(endpoint string) (list []map[string]interface{}, ok bool) {
	resp, err := serversProbeClient.Get(endpoint)
	if err != nil {
		return nil, false
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, false
	}
	body, err := ioutil.ReadAll(resp.Body)
	if err != nil || json.Unmarshal(body, &list) != nil {
		// Discard any partially decoded data.
		return []map[string]interface{}{}, true
	}
	return list, true
}
// ServersPost creates a server record from a JSON request body and writes the
// resulting record back as JSON.
//
// Validation and defaults:
//   - Type is required and must be "訓練" (training) or "推理" (inference);
//     anything else is rejected with 400.
//   - Name defaults to a random UUID.
//   - Port defaults to 7860.
//   - When IP is empty the server is echoed back WITHOUT being persisted
//     (automatic provisioning is not implemented yet).
//
// An unreadable or malformed body yields 400; a failed insert yields 500.
func ServersPost(w http.ResponseWriter, r *http.Request) {
	defer r.Body.Close()
	body, err := ioutil.ReadAll(r.Body)
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte(err.Error()))
		return
	}
	var server models.Server
	if err := json.Unmarshal(body, &server); err != nil {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte(err.Error()))
		return
	}
	// A server must declare its type: training (訓練) or inference (推理).
	if server.Type != "訓練" && server.Type != "推理" {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte("必須指定類型:訓練|推理"))
		return
	}
	if server.Name == "" {
		server.Name = uuid.New().String()
	}
	if server.Port == 0 {
		server.Port = 7860
	}
	if server.IP != "" {
		if err := configs.ORMDB().Create(&server).Error; err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			w.Write([]byte(err.Error()))
			return
		}
	}
	// TODO: when IP is empty, provision a new server instead of only echoing it.
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	w.Write(utils.ToJSON(server))
}
func ServersItemGet(w http.ResponseWriter, r *http.Request) {
server := models.Server{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
configs.ORMDB().First(&server)
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(server))
}
func ServersItemPatch(w http.ResponseWriter, r *http.Request) {
server := models.Server{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
configs.ORMDB().First(&server)
// TODO: update server
configs.ORMDB().Save(&server)
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(server))
}
func ServersItemDelete(w http.ResponseWriter, r *http.Request) {
server := models.Server{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
configs.ORMDB().Delete(&server)
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(server))
}