Files
ai/models/server.go
2023-07-02 04:22:24 +08:00

245 lines
7.6 KiB
Go

package models
import (
"database/sql/driver"
"encoding/json"
"fmt"
"io/ioutil"
"main/configs"
"net/http"
"path/filepath"
"time"
"gopkg.in/yaml.v2"
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common"
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/errors"
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile"
cvm "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/cvm/v20170312"
)
type ModelList []int
func (list *ModelList) Scan(value interface{}) error {
return json.Unmarshal(value.([]byte), list)
}
func (list ModelList) Value() (driver.Value, error) {
return json.Marshal(list)
}
type Server struct {
ID string `json:"id" gorm:"primary_key"`
Name string `json:"name"`
Type string `json:"type"` // (训练|推理)
IP string `json:"ip"`
Port int `json:"port"`
Status string `json:"status"` // (異常|初始化|閒置|就緒|工作中|關閉中)
UserName string `json:"username"`
Password string `json:"password"`
Models ModelList `json:"models"`
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
}
var config = struct {
TencentCloud struct {
SecretId string `yaml:"SecretId"`
SecretKey string `yaml:"SecretKey"`
Region string `yaml:"Region"`
} `yaml:"TencentCloud"`
}{}
func init() {
configs.ORMDB().AutoMigrate(&Server{})
// 檢查所有服務器的狀態, 無效的服務器設置為異常
var servers []Server
configs.ORMDB().Find(&servers)
for _, server := range servers {
server.CheckStatus()
}
// 讀取配置文件
absPath, _ := filepath.Abs("./data/config.yaml")
configFile, err := ioutil.ReadFile(absPath)
if err != nil {
panic(fmt.Errorf("讀取配置文件失敗: %v", err))
}
if err := yaml.Unmarshal(configFile, &config); err != nil {
panic(fmt.Errorf("格式化配置文件失敗: %v", err))
}
// 初始化检查默认服务器
if err := InitDefaultServer(); err != nil {
panic(fmt.Errorf("初始化默认服务器失败: %v", err))
}
}
// 检查默认服务器是否存在, 不存在则添加
func InitDefaultServer() (err error) {
var server Server
if err = configs.ORMDB().Where("id = ?", "default").First(&server).Error; err != nil {
server = Server{ID: "default", IP: "106.15.192.42", Port: 7860, Type: "推理", Status: "就绪"}
if err = configs.ORMDB().Create(&server).Error; err != nil {
return fmt.Errorf("创建默认服务器失败: %v", err)
}
}
// 初始化服务器中的模型列表
if err = server.InitModels(); err != nil {
return fmt.Errorf("初始化服务器中的模型列表失败: %v", err)
}
return
}
// 初始化服务器中的模型列表
func (server *Server) InitModels() (err error) {
// 获取服务器中的模型列表
resp, err := http.Get(fmt.Sprintf("http://%s:%d/sdapi/v1/sd-models", server.IP, server.Port))
if err != nil {
return fmt.Errorf("获取服务器中的模型列表失败: %v", err)
}
defer resp.Body.Close()
// 解码JSON (数组)
var data []map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
return err
}
// 从数据库检查此模型hash是否存在
for _, item := range data {
var model Model
if err := configs.ORMDB().Where("hash = ?", item["sha256"].(string)).First(&model).Error; err != nil {
// 不存在则添加
model = Model{
Name: item["model_name"].(string),
Hash: item["sha256"].(string),
ModelCheckpoint: item["title"].(string),
ModelPath: item["filename"].(string), // TODO: 下载到本地
ServerID: server.ID,
Type: "ckp",
}
// TODO: 下载到本地
// 添加到数据库
if err := configs.ORMDB().Create(&model).Error; err != nil {
return fmt.Errorf("添加模型到数据库失败: %v", err)
}
// 添加到模型列表
server.Models = append(server.Models, model.ID)
}
}
// 更新数据库
if err := configs.ORMDB().Save(&server).Error; err != nil {
return fmt.Errorf("更新数据库失败: %v", err)
}
return
}
// 创建一台新服务器
func NewServer(server_type string) (server Server, err error) {
// 调用 API 创建一台新服务器(通過腾讯云API創建服務器)
client, err := cvm.NewClient(common.NewCredential(config.TencentCloud.SecretId, config.TencentCloud.SecretKey), config.TencentCloud.Region, profile.NewClientProfile())
if err != nil {
return server, fmt.Errorf("初始化騰訊雲SDK客戶端失敗: %v", err)
}
// 实例化一个请求对象, 指定啓動模板, 以創建指定規格的服務器
request := cvm.NewRunInstancesRequest()
request.LaunchTemplate = &cvm.LaunchTemplate{LaunchTemplateId: common.StringPtr("lt-ks6y5evh")}
response, err := client.RunInstances(request)
if _, ok := err.(*errors.TencentCloudSDKError); ok {
return server, fmt.Errorf("已返回 API 错误: %v", err)
}
if err != nil {
return server, fmt.Errorf("运行实例失败: %v", err)
}
fmt.Println("創建服務器成功:", response.Response.InstanceIdSet[0])
// 获取服务器信息
var get_server_info = func(InstanceIdSet *string) (server Server, err error) {
response2, err := client.DescribeInstances(cvm.NewDescribeInstancesRequest())
if err != nil {
return server, fmt.Errorf("獲取實例詳情失敗: %v", err)
}
for _, instance := range response2.Response.InstanceSet {
if *instance.InstanceId != *InstanceIdSet {
server.ID = *instance.InstanceId
server.Name = *instance.InstanceName
server.IP = *instance.PublicIpAddresses[0]
server.Port = 7890
server.Status = *instance.InstanceState
configs.ORMDB().Create(&server)
return server, nil
}
}
return server, fmt.Errorf("未取得實例詳情: %v", err)
}
// 等待服务器创建完成
return get_server_info(response.Response.InstanceIdSet[0])
}
// 注销服务器
func (server *Server) Delete() error {
client, err := cvm.NewClient(common.NewCredential(config.TencentCloud.SecretId, config.TencentCloud.SecretKey), config.TencentCloud.Region, profile.NewClientProfile())
if err != nil {
return fmt.Errorf("初始化騰訊雲SDK客戶端失敗: %v", err)
}
request := cvm.NewTerminateInstancesRequest()
request.InstanceIds = []*string{common.StringPtr(server.ID)}
response, err := client.TerminateInstances(request)
if _, ok := err.(*errors.TencentCloudSDKError); ok {
return fmt.Errorf("已返回 API 错误: %v", err)
}
if err != nil {
return fmt.Errorf("註銷實例失敗: %v", err)
}
// 從列表中刪除服務器
configs.ORMDB().Delete(&server)
fmt.Println("註銷服務器成功:", server.ID, response.Response)
return nil
}
// 檢查服務器是否正常
func (server *Server) CheckStatus() error {
switch server.Type {
case "训练":
resp, err := http.Get(fmt.Sprintf("http://%s:%d/dreambooth/status", server.IP, server.Port))
if err != nil {
server.Status = "異常"
return err
}
defer resp.Body.Close()
// 解碼JSON
var data map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
return err
}
// 解碼JSON
var current_state map[string]interface{}
if err := json.Unmarshal([]byte(data["current_state"].(string)), &current_state); err != nil {
return err
}
//log.Println("current_state:", current_state)
// 檢查服務器是否正常
if !current_state["active"].(bool) {
server.Status = "異常"
return fmt.Errorf("服務器狀態異常: active=false")
}
server.Status = "正常"
case "推理":
server.Status = "就绪"
default:
server.Status = "異常"
}
// 檢查服務器是否正常
return nil
}