同步服务器当前模型

This commit is contained in:
2023-07-02 05:29:04 +08:00
parent 4afb719d13
commit 30286dd0c0
2 changed files with 49 additions and 7 deletions

View File

@@ -121,6 +121,24 @@ func (model *Model) Inference(image_list []Image, callback func(Image)) {
configs.ORMDB().Take(&server) configs.ORMDB().Take(&server)
} }
// 排队等待切换模型(先检查是否已经切换完成)
for {
var form = struct {
Components []struct {
ID int `json:"id"`
Type string `json:"type"`
Props struct {
Value string `json:"value"`
}
} `json:"components"`
}{}
// 检查当前是否为目标模型, 不是则执行切换模型 http://
if err := goreq.Get(fmt.Sprintf("http://%s:%d/config", server.IP, server.Port)).Do().BindJSON(&form); err != nil {
log.Println("获取推理机配置失败:", err)
return
}
}
// 发送的参数 // 发送的参数
var img = image_list[0] var img = image_list[0]
var datx map[string]interface{} = make(map[string]interface{}) var datx map[string]interface{} = make(map[string]interface{})

View File

@@ -16,6 +16,7 @@ import (
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/errors" "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/errors"
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile" "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile"
cvm "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/cvm/v20170312" cvm "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/cvm/v20170312"
"github.com/zhshch2002/goreq"
) )
type ModelList []int type ModelList []int
@@ -32,12 +33,13 @@ type Server struct {
ID string `json:"id" gorm:"primary_key"` ID string `json:"id" gorm:"primary_key"`
Name string `json:"name"` Name string `json:"name"`
Type string `json:"type"` // (训练|推理) Type string `json:"type"` // (训练|推理)
IP string `json:"ip"` IP string `json:"ip"` // 服务器IP
Port int `json:"port"` Port int `json:"port"` // 7860
Status string `json:"status"` // (異常|初始化|閒置|就緒|工作中|關閉中) Status string `json:"status"` // (異常|初始化|閒置|就緒|工作中|關閉中)
UserName string `json:"username"` UserName string `json:"username"` // 用户名
Password string `json:"password"` Password string `json:"password"` // 用户密码
Models ModelList `json:"models"` Models ModelList `json:"models"` // 服务器中所有模型
ModelID int `json:"model_id"` // 当前加载的模型
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
} }
@@ -132,6 +134,28 @@ func (server *Server) InitModels() (err error) {
if err := configs.ORMDB().Save(&server).Error; err != nil { if err := configs.ORMDB().Save(&server).Error; err != nil {
return fmt.Errorf("更新数据库失败: %v", err) return fmt.Errorf("更新数据库失败: %v", err)
} }
// 检查服务器当前加载的模型
if server.ModelID == 0 {
var form = struct {
SdCheckpointHash string `json:"sd_checkpoint_hash"`
SdModelCheckpoint string `json:"sd_model_checkpoint"`
}{}
// 检查当前是否为目标模型, 不是则执行切换模型
if err := goreq.Get(fmt.Sprintf("http://%s:%d/sdapi/v1/options", server.IP, server.Port)).Do().BindJSON(&form); err != nil {
return fmt.Errorf("获取推理机配置失败: %v", err)
}
fmt.Println("当前模型:", form.SdModelCheckpoint)
var model Model
if err := configs.ORMDB().Where("model_checkpoint = ?", form.SdModelCheckpoint).First(&model).Error; err != nil {
return fmt.Errorf("获取模型信息失败: %v", err)
}
server.ModelID = model.ID
if err := configs.ORMDB().Save(&server).Error; err != nil {
return fmt.Errorf("更新数据库失败: %v", err)
}
}
return return
} }