同步服务器当前模型
This commit is contained in:
@@ -121,6 +121,24 @@ func (model *Model) Inference(image_list []Image, callback func(Image)) {
|
||||
configs.ORMDB().Take(&server)
|
||||
}
|
||||
|
||||
// 排队等待切换模型(先检查是否已经切换完成)
|
||||
for {
|
||||
var form = struct {
|
||||
Components []struct {
|
||||
ID int `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Props struct {
|
||||
Value string `json:"value"`
|
||||
}
|
||||
} `json:"components"`
|
||||
}{}
|
||||
// 检查当前是否为目标模型, 不是则执行切换模型 http://
|
||||
if err := goreq.Get(fmt.Sprintf("http://%s:%d/config", server.IP, server.Port)).Do().BindJSON(&form); err != nil {
|
||||
log.Println("获取推理机配置失败:", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 发送的参数
|
||||
var img = image_list[0]
|
||||
var datx map[string]interface{} = make(map[string]interface{})
|
||||
|
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/errors"
|
||||
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile"
|
||||
cvm "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/cvm/v20170312"
|
||||
"github.com/zhshch2002/goreq"
|
||||
)
|
||||
|
||||
type ModelList []int
|
||||
@@ -31,13 +32,14 @@ func (list ModelList) Value() (driver.Value, error) {
|
||||
type Server struct {
|
||||
ID string `json:"id" gorm:"primary_key"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"` // (训练|推理)
|
||||
IP string `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
Status string `json:"status"` // (異常|初始化|閒置|就緒|工作中|關閉中)
|
||||
UserName string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Models ModelList `json:"models"`
|
||||
Type string `json:"type"` // (训练|推理)
|
||||
IP string `json:"ip"` // 服务器IP
|
||||
Port int `json:"port"` // 7860
|
||||
Status string `json:"status"` // (異常|初始化|閒置|就緒|工作中|關閉中)
|
||||
UserName string `json:"username"` // 用户名
|
||||
Password string `json:"password"` // 用户密码
|
||||
Models ModelList `json:"models"` // 服务器中所有模型
|
||||
ModelID int `json:"model_id"` // 当前加载的模型
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||
}
|
||||
@@ -132,6 +134,28 @@ func (server *Server) InitModels() (err error) {
|
||||
if err := configs.ORMDB().Save(&server).Error; err != nil {
|
||||
return fmt.Errorf("更新数据库失败: %v", err)
|
||||
}
|
||||
|
||||
// 检查服务器当前加载的模型
|
||||
if server.ModelID == 0 {
|
||||
var form = struct {
|
||||
SdCheckpointHash string `json:"sd_checkpoint_hash"`
|
||||
SdModelCheckpoint string `json:"sd_model_checkpoint"`
|
||||
}{}
|
||||
// 检查当前是否为目标模型, 不是则执行切换模型
|
||||
if err := goreq.Get(fmt.Sprintf("http://%s:%d/sdapi/v1/options", server.IP, server.Port)).Do().BindJSON(&form); err != nil {
|
||||
return fmt.Errorf("获取推理机配置失败: %v", err)
|
||||
}
|
||||
fmt.Println("当前模型:", form.SdModelCheckpoint)
|
||||
var model Model
|
||||
if err := configs.ORMDB().Where("model_checkpoint = ?", form.SdModelCheckpoint).First(&model).Error; err != nil {
|
||||
return fmt.Errorf("获取模型信息失败: %v", err)
|
||||
}
|
||||
server.ModelID = model.ID
|
||||
if err := configs.ORMDB().Save(&server).Error; err != nil {
|
||||
return fmt.Errorf("更新数据库失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user