Files
ai/routers/servers.go
2023-06-05 15:50:47 +08:00

138 lines
3.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package routers
import (
"encoding/json"
"fmt"
"io/ioutil"
"main/configs"
"main/models"
"main/utils"
"net/http"
"github.com/google/uuid"
"github.com/gorilla/mux"
)
// ServersGet handles GET requests for a paginated list of servers.
//
// Query parameters:
//   - page:     1-based page number (utils.ParamInt fallback 1); values below
//     1 are clamped to 1.
//   - pageSize: servers per page (utils.ParamInt fallback 10); values below 1
//     are replaced with 10.
//
// For every server on the page it calls CheckStatus and fetches the model
// list from http://<IP>:<Port>/sdapi/v1/sd-models. If that fetch fails in any
// way, the server's Models is set to an empty (non-nil) list. The page, the
// total count and whether a next page exists are written via
// listview.WriteJSON. Database failures respond 500 with a plain-text message.
func ServersGet(w http.ResponseWriter, r *http.Request) {
	var listview models.ListView
	listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1)
	listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10)
	// A page below 1 would yield a negative OFFSET. A non-positive page size
	// would yield a meaningless or unbounded LIMIT, and every listed server
	// costs an outbound HTTP call.
	if listview.Page < 1 {
		listview.Page = 1
	}
	if listview.PageSize < 1 {
		listview.PageSize = 10
	}

	db := configs.ORMDB()
	// Total number of servers, used for the Next flag below.
	if err := db.Model(&models.Server{}).Count(&listview.Total).Error; err != nil {
		fmt.Println("查詢服務器失敗:", err)
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte("查詢服務器失敗"))
		return
	}
	// Servers on the requested page.
	var servers []models.Server
	offset := (listview.Page - 1) * listview.PageSize
	if err := db.Offset(offset).Limit(listview.PageSize).Find(&servers).Error; err != nil {
		fmt.Println("查詢服務器失敗:", err)
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte("查詢服務器失敗"))
		return
	}

	for _, server := range servers {
		// The CheckStatus error is ignored here, so servers that fail the check
		// are still listed. NOTE(review): presumably CheckStatus records its
		// outcome on the server value; confirm in models.
		server.CheckStatus()
		url := fmt.Sprintf("http://%s:%d/sdapi/v1/sd-models", server.IP, server.Port)
		server.Models = fetchSDModels(r, url)
		listview.List = append(listview.List, server)
	}
	listview.Next = listview.Page*listview.PageSize < int(listview.Total)
	listview.WriteJSON(w)
}

// fetchSDModels GETs url (a server's /sdapi/v1/sd-models endpoint) and
// decodes the JSON array it returns.
//
// The outgoing request inherits parent's context, so it is abandoned if the
// incoming request is cancelled. Any failure (request construction, transport
// error, non-200 status, unreadable body, invalid JSON, or a JSON null)
// yields an empty, non-nil slice. The response body is always closed before
// returning, so the connection is not leaked.
//
// NOTE(review): there is no client-side timeout beyond parent's context; an
// unresponsive server delays the listing until the caller gives up.
func fetchSDModels(parent *http.Request, url string) []map[string]interface{} {
	empty := []map[string]interface{}{}

	req, err := http.NewRequestWithContext(parent.Context(), http.MethodGet, url, nil)
	if err != nil {
		return empty
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return empty
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return empty
	}

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return empty
	}
	var list []map[string]interface{}
	// A failed decode may leave list partially filled; discard it entirely.
	if err := json.Unmarshal(body, &list); err != nil || list == nil {
		return empty
	}
	return list
}
// ServersPost handles POST requests that register an existing server.
//
// The request body is a JSON-encoded models.Server, capped at 1 MiB.
// Validation rules:
//   - Type must be "訓練" (training) or "推理" (inference).
//   - An empty Name is replaced with a random UUID.
//   - A non-positive Port defaults to 7860.
//   - IP is required, because creating servers automatically is not
//     supported yet.
//   - No server with the same IP may already be registered, and
//     CheckStatus must succeed.
//
// Validation failures respond 400 with a plain-text message. Database
// failures respond 500. On success the stored server is written back as JSON.
func ServersPost(w http.ResponseWriter, r *http.Request) {
	var server models.Server

	// The body comes from an untrusted client and is read fully into memory,
	// so bound its size; a registration payload is tiny.
	r.Body = http.MaxBytesReader(w, r.Body, 1<<20)
	defer r.Body.Close()
	body, err := ioutil.ReadAll(r.Body)
	if err != nil {
		fmt.Println("獲取數據失敗:", err)
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte(err.Error()))
		return
	}
	if err := json.Unmarshal(body, &server); err != nil {
		fmt.Println("解碼JSON失敗:", err)
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte(err.Error()))
		return
	}

	// A type is mandatory: the server must be either training or inference.
	if server.Type != "訓練" && server.Type != "推理" {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte("必須指定類型:訓練|推理"))
		return
	}
	if server.Name == "" {
		server.Name = uuid.New().String()
	}
	if server.Port <= 0 {
		server.Port = 7860
	}
	if server.IP == "" {
		// TODO: provision a new server automatically when no IP is given.
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte("必須指定IP, 因爲當前禁止自動創建服務器"))
		return
	}

	// Reject duplicates; servers are identified by IP here.
	var count int64
	if err := configs.ORMDB().Model(&models.Server{}).Where("ip = ?", server.IP).Count(&count).Error; err != nil {
		fmt.Println("查詢服務器失敗:", err)
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte("查詢服務器失敗"))
		return
	}
	if count > 0 {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte("服務器已經存在"))
		return
	}

	// Only register servers for which CheckStatus succeeds.
	if err := server.CheckStatus(); err != nil {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte("服務器狀態錯誤:" + err.Error()))
		return
	}

	// Persist; a failed insert must not be reported as success.
	if err := configs.ORMDB().Create(&server).Error; err != nil {
		fmt.Println("創建服務器失敗:", err)
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte("創建服務器失敗"))
		return
	}
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	w.Write(utils.ToJSON(server))
}
func ServersItemGet(w http.ResponseWriter, r *http.Request) {
server := models.Server{ID: mux.Vars(r)["id"]}
configs.ORMDB().First(&server)
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(server))
}
func ServersItemPatch(w http.ResponseWriter, r *http.Request) {
server := models.Server{ID: mux.Vars(r)["id"]}
configs.ORMDB().First(&server)
// TODO: update server
configs.ORMDB().Save(&server)
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(server))
}
// ServersItemDelete handles DELETE requests for the single server identified
// by the "id" route variable.
//
// On success it writes back, as JSON, a models.Server that carries only the ID
// (the record is not loaded before deletion). It responds 404 when the ID is
// empty or no row was deleted, and 500 when the delete fails.
func ServersItemDelete(w http.ResponseWriter, r *http.Request) {
	server := models.Server{ID: mux.Vars(r)["id"]}
	// Never issue a delete without an ID condition; depending on the ORM
	// version and settings, that could affect every row.
	if server.ID == "" {
		w.WriteHeader(http.StatusNotFound)
		w.Write([]byte("服務器不存在"))
		return
	}

	result := configs.ORMDB().Delete(&server)
	if result.Error != nil {
		fmt.Println("刪除服務器失敗:", result.Error)
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte("刪除服務器失敗"))
		return
	}
	if result.RowsAffected == 0 {
		w.WriteHeader(http.StatusNotFound)
		w.Write([]byte("服務器不存在"))
		return
	}

	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	w.Write(utils.ToJSON(server))
}