Files
ai/models/server.go
2023-06-21 15:40:00 +08:00

91 lines
2.2 KiB
Go

package models
import (
"database/sql/driver"
"encoding/json"
"fmt"
"main/configs"
"net/http"
"time"
)
// ModelList is a list of model names that is stored in the database as one
// JSON-encoded array value.
type ModelList []string

// Scan implements the database/sql Scanner contract by decoding a JSON array
// read from the database into list.
//
// A nil value (SQL NULL) resets list to nil. The raw value is accepted as
// either []byte or string, because drivers differ in how they hand back text
// columns. Any other type is reported as an error rather than causing a panic.
func (list *ModelList) Scan(value interface{}) error {
	switch raw := value.(type) {
	case nil:
		*list = nil
		return nil
	case []byte:
		return json.Unmarshal(raw, list)
	case string:
		return json.Unmarshal([]byte(raw), list)
	default:
		return fmt.Errorf("models: cannot scan %T into ModelList", value)
	}
}

// Value implements driver.Valuer by encoding list as JSON.
// A nil list is encoded as the JSON literal null.
func (list ModelList) Value() (driver.Value, error) {
	return json.Marshal(list)
}
// Server is the record kept for one server: identifier, display name, kind,
// network address, credentials, status, model list and timestamps.
// The gorm tags mark ID as the primary key and tag CreatedAt / UpdatedAt with
// autoCreateTime / autoUpdateTime.
// NOTE(review): Password is exposed under the "password" JSON key whenever a
// Server is marshalled — confirm that this is intended.
type Server struct {
ID string `json:"id" gorm:"primary_key"`
Name string `json:"name"`
Type string `json:"type"` // server kind: "训练" (training) or "推理" (inference)
IP string `json:"ip"`
Port int `json:"port"`
// Status values listed by the original author: "異常" (abnormal), "初始化"
// (initializing), "閒置" (idle), "就緒" (ready), "工作中" (working),
// "關閉中" (shutting down).
// NOTE(review): CheckStatus also stores "正常" (normal), which is not in
// that list.
Status string `json:"status"`
UserName string `json:"username"`
Password string `json:"password"`
// Models is converted to and from its database form by ModelList's
// Scan and Value methods (JSON).
Models ModelList `json:"models"`
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
}
// GetServers fetches all Server rows using the handle returned by
// configs.ORMDB() and returns them together with the error recorded by the
// query, if any.
func GetServers() (servers []Server, err error) {
	query := configs.ORMDB().Find(&servers)
	err = query.Error
	return servers, err
}
// CheckStatus probes the server and records the outcome in server.Status.
//
// For a training server (Type "训练") it requests
// http://IP:Port/dreambooth/status and expects a JSON object whose
// "current_state" field is a string containing a second JSON document with a
// boolean "active" field. Status becomes "正常" when active is true. Any
// failure sets Status to "異常" and is returned as an error: a request error,
// a non-2xx response, a malformed payload, or active being false or absent.
//
// Inference servers (Type "推理") and unknown types are not probed: Status is
// set to "異常" and nil is returned.
func (server *Server) CheckStatus() error {
	// Upper bound for the whole probe. Without it an unresponsive host could
	// block the caller indefinitely, because the default HTTP client has no
	// timeout.
	const probeTimeout = 5 * time.Second

	switch server.Type {
	case "训练":
		// Mark the server abnormal up front so that every early return below
		// leaves a consistent status instead of a stale one.
		server.Status = "異常"

		client := http.Client{Timeout: probeTimeout}
		resp, err := client.Get(fmt.Sprintf("http://%s:%d/dreambooth/status", server.IP, server.Port))
		if err != nil {
			return fmt.Errorf("requesting training server status: %w", err)
		}
		defer resp.Body.Close()

		if resp.StatusCode < 200 || resp.StatusCode > 299 {
			return fmt.Errorf("training server status endpoint returned %s", resp.Status)
		}

		// Outer document: "current_state" must be a JSON string. Decoding into
		// a typed struct turns a missing or mistyped field into an error
		// rather than a failed type assertion.
		var payload struct {
			CurrentState string `json:"current_state"`
		}
		if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
			return fmt.Errorf("decoding status response: %w", err)
		}

		// Inner document: the string itself holds JSON with the "active" flag.
		var state struct {
			Active bool `json:"active"`
		}
		if err := json.Unmarshal([]byte(payload.CurrentState), &state); err != nil {
			return fmt.Errorf("decoding current_state: %w", err)
		}

		if !state.Active {
			return fmt.Errorf("服務器狀態異常: active=false")
		}
		server.Status = "正常"
	case "推理":
		// No probe exists for inference servers in this code; they are always
		// reported as abnormal.
		server.Status = "異常"
	default:
		server.Status = "異常"
	}
	return nil
}
// init migrates the Server table and then re-checks every stored server once
// at program start, writing back the status of any server whose status
// changed as a result of the check.
//
// Failures are printed and otherwise tolerated: init cannot return an error,
// and aborting start-up because one server is unreachable is not desirable.
func init() {
	// NOTE(review): the result of AutoMigrate is not inspected, as in the
	// original code — confirm whether a failed migration should be surfaced.
	configs.ORMDB().AutoMigrate(&Server{})

	var servers []Server
	if err := configs.ORMDB().Find(&servers).Error; err != nil {
		fmt.Printf("models: loading servers for status check: %v\n", err)
		return
	}

	// Iterate by index: ranging by value would hand CheckStatus a copy of
	// each element, so the status it computes would be thrown away.
	for i := range servers {
		srv := &servers[i]
		previous := srv.Status

		// The returned error is dropped on purpose: CheckStatus records a
		// failed check in srv.Status, which is what gets persisted below.
		_ = srv.CheckStatus()

		// Skip the write when nothing changed, and never issue an update
		// without a primary key to scope it to a single row.
		if srv.Status == previous || srv.ID == "" {
			continue
		}
		if err := configs.ORMDB().Model(srv).Update("status", srv.Status).Error; err != nil {
			fmt.Printf("models: saving status of server %q: %v\n", srv.ID, err)
		}
	}
}