diff --git a/models/Model.go b/models/Model.go index 8e3ff5e..cd90383 100644 --- a/models/Model.go +++ b/models/Model.go @@ -121,6 +121,24 @@ func (model *Model) Inference(image_list []Image, callback func(Image)) { configs.ORMDB().Take(&server) } + // 排队等待切换模型(先检查是否已经切换完成) + for { + var form = struct { + Components []struct { + ID int `json:"id"` + Type string `json:"type"` + Props struct { + Value string `json:"value"` + } + } `json:"components"` + }{} + // 检查当前是否为目标模型, 不是则执行切换模型 http:// + if err := goreq.Get(fmt.Sprintf("http://%s:%d/config", server.IP, server.Port)).Do().BindJSON(&form); err != nil { + log.Println("获取推理机配置失败:", err) + return + } + } + // 发送的参数 var img = image_list[0] var datx map[string]interface{} = make(map[string]interface{}) diff --git a/models/server.go b/models/server.go index 8ff309b..9a3e2f8 100644 --- a/models/server.go +++ b/models/server.go @@ -16,6 +16,7 @@ import ( "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/errors" "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile" cvm "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/cvm/v20170312" + "github.com/zhshch2002/goreq" ) type ModelList []int @@ -31,13 +32,14 @@ func (list ModelList) Value() (driver.Value, error) { type Server struct { ID string `json:"id" gorm:"primary_key"` Name string `json:"name"` - Type string `json:"type"` // (训练|推理) - IP string `json:"ip"` - Port int `json:"port"` - Status string `json:"status"` // (異常|初始化|閒置|就緒|工作中|關閉中) - UserName string `json:"username"` - Password string `json:"password"` - Models ModelList `json:"models"` + Type string `json:"type"` // (训练|推理) + IP string `json:"ip"` // 服务器IP + Port int `json:"port"` // 7860 + Status string `json:"status"` // (異常|初始化|閒置|就緒|工作中|關閉中) + UserName string `json:"username"` // 用户名 + Password string `json:"password"` // 用户密码 + Models ModelList `json:"models"` // 服务器中所有模型 + ModelID int `json:"model_id"` // 当前加载的模型 CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` } @@ -132,6 +134,28 @@ func (server *Server) InitModels() (err error) { if err := configs.ORMDB().Save(&server).Error; err != nil { return fmt.Errorf("更新数据库失败: %v", err) } + + // 检查服务器当前加载的模型 + if server.ModelID == 0 { + var form = struct { + SdCheckpointHash string `json:"sd_checkpoint_hash"` + SdModelCheckpoint string `json:"sd_model_checkpoint"` + }{} + // 检查当前是否为目标模型, 不是则执行切换模型 + if err := goreq.Get(fmt.Sprintf("http://%s:%d/sdapi/v1/options", server.IP, server.Port)).Do().BindJSON(&form); err != nil { + return fmt.Errorf("获取推理机配置失败: %v", err) + } + fmt.Println("当前模型:", form.SdModelCheckpoint) + var model Model + if err := configs.ORMDB().Where("model_checkpoint = ?", form.SdModelCheckpoint).First(&model).Error; err != nil { + return fmt.Errorf("获取模型信息失败: %v", err) + } + server.ModelID = model.ID + if err := configs.ORMDB().Save(&server).Error; err != nil { + return fmt.Errorf("更新数据库失败: %v", err) + } + } + return }