模型安全加载

This commit is contained in:
2023-07-10 08:32:51 +08:00
parent 30de6b2547
commit 2ef095c8e2
2 changed files with 34 additions and 37 deletions

View File

@@ -94,7 +94,7 @@ func InitDefaultServer() (err error) {
// 初始化服务器中的模型列表
func (server *Server) InitModels() (err error) {
// 先让服务器更新模型列表 /sdapi/v1/reload-checkpoint
// 刷新检查点 /sdapi/v1/reload-checkpoint
resp_update, err := http.Post(fmt.Sprintf("http://%s:%d/sdapi/v1/refresh-checkpoints", server.IP, server.Port), "application/json", nil)
if err != nil {
return fmt.Errorf("更新服务器中的模型列表失败: %v", err)
@@ -117,10 +117,6 @@ func (server *Server) InitModels() (err error) {
// 从数据库检查此模型hash是否存在
for _, item := range data_lora {
// 打印为格式化的JSON
//b, _ := json.MarshalIndent(item, "", " ")
//fmt.Println(string(b))
var model Model
if err := configs.ORMDB().Where("hash = ?", item["path"].(string)).First(&model).Error; err != nil {
// 不存在则添加
@@ -146,19 +142,27 @@ func (server *Server) InitModels() (err error) {
return fmt.Errorf("更新数据库失败: %v", err)
}
// 模型HASH生成
// 刷新检查点
if err = goreq.Post(fmt.Sprintf("http://%s:%d/sdapi/v1/refresh-checkpoints", server.IP, server.Port)).Do().Err; err != nil {
return fmt.Errorf("刷新检查点失败: %v", err)
}
// 获取服务器中的模型列表(ckp)
dexs := []struct {
Name string `json:"model_name"`
Hash string `json:"sha256"`
ModelCheckpoint string `json:"title"`
ModelPath string `json:"filename"`
}{}
if err = goreq.Post(fmt.Sprintf("http://%s:%d/sdapi/v1/refresh-checkpoints", server.IP, server.Port)).Do().BindJSON(&dexs); err != nil {
// 获取服务器中的模型列表(ckp)
if err = goreq.Get(fmt.Sprintf("http://%s:%d/sdapi/v1/sd-models", server.IP, server.Port)).Do().BindJSON(&dexs); err != nil {
return fmt.Errorf("获取服务器中的模型列表失败: %v", err)
}
for _, item := range dexs {
// 如果hash为空, 则逐一加载这些模型使其生成hash
if item.Hash == "" {
fmt.Println("加载模型:", item.ModelCheckpoint)
if err := goreq.Post(fmt.Sprintf("http://%s:%d/sdapi/v1/options", server.IP, server.Port)).SetJsonBody(map[string]interface{}{
"sd_model_checkpoint": item.ModelCheckpoint,
"CLIP_stop_at_last_layers": 2,
@@ -168,40 +172,27 @@ func (server *Server) InitModels() (err error) {
}
}
// 获取服务器中的模型列表(ckp)
resp, err := http.Get(fmt.Sprintf("http://%s:%d/sdapi/v1/sd-models", server.IP, server.Port))
if err != nil {
// 重新获取服务器中的模型列表(ckp), 忽略hash为空的模型
if err = goreq.Get(fmt.Sprintf("http://%s:%d/sdapi/v1/sd-models", server.IP, server.Port)).Do().BindJSON(&dexs); err != nil {
return fmt.Errorf("获取服务器中的模型列表失败: %v", err)
}
defer resp.Body.Close()
// 解码JSON (数组)
var data []map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
return err
}
// 如果hash为空, 则逐一加载这些模型使其生成hash
// 然后再次更新服务器中的模型列表
// 从数据库检查此模型hash是否存在
for _, item := range data {
// 打印为格式化的JSON
b, _ := json.MarshalIndent(item, "", " ")
fmt.Println(string(b))
for _, item := range dexs {
if item.Hash == "" {
fmt.Println("忽略模型:", item.ModelCheckpoint)
continue
}
var model Model
if err := configs.ORMDB().Where("hash = ?", item["sha256"].(string)).First(&model).Error; err != nil {
if err := configs.ORMDB().Where("hash = ?", item.Hash).First(&model).Error; err != nil {
// 不存在则添加
model = Model{
Name: item["model_name"].(string),
Hash: item["sha256"].(string),
ModelCheckpoint: item["title"].(string),
ModelPath: item["filename"].(string), // TODO: 下载到本地
Name: item.Name,
Hash: item.Hash,
ModelCheckpoint: item.ModelCheckpoint,
ModelPath: item.ModelPath,
ServerID: server.ID,
Type: "ckp",
}
// TODO: 下载到本地
// 添加到数据库
if err := configs.ORMDB().Create(&model).Error; err != nil {
return fmt.Errorf("添加模型到数据库失败: %v", err)