From f8a51382f581247dd979422cc08274fd4bf2a189 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A1=9C=E8=8F=AF?= Date: Sat, 8 Jul 2023 11:24:59 +0800 Subject: [PATCH] =?UTF-8?q?hash=20=E6=A3=80=E6=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- models/server.go | 38 ++++++++++++++++++++++++++++++++++++-- routers/models.go | 15 +++++++++++++++ 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/models/server.go b/models/server.go index f078978..5893273 100644 --- a/models/server.go +++ b/models/server.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io/ioutil" + "log" "main/configs" "net/http" "path/filepath" @@ -93,6 +94,14 @@ func InitDefaultServer() (err error) { // 初始化服务器中的模型列表 func (server *Server) InitModels() (err error) { + // 先让服务器更新模型列表 /sdapi/v1/reload-checkpoint + resp_update, err := http.Post(fmt.Sprintf("http://%s:%d/sdapi/v1/refresh-checkpoints", server.IP, server.Port), "application/json", nil) + if err != nil { + return fmt.Errorf("更新服务器中的模型列表失败: %v", err) + } + defer resp_update.Body.Close() + log.Println("更新服务器中的模型列表:", resp_update.Status) + // 获取服务器中的模型列表(Lora) resp_lora, err := http.Get(fmt.Sprintf("http://%s:%d/sdapi/v1/loras", server.IP, server.Port)) if err != nil { @@ -109,8 +118,8 @@ func (server *Server) InitModels() (err error) { // 从数据库检查此模型hash是否存在 for _, item := range data_lora { // 打印为格式化的JSON - b, _ := json.MarshalIndent(item, "", " ") - fmt.Println(string(b)) + //b, _ := json.MarshalIndent(item, "", " ") + //fmt.Println(string(b)) var model Model if err := configs.ORMDB().Where("hash = ?", item["path"].(string)).First(&model).Error; err != nil { @@ -137,6 +146,28 @@ func (server *Server) InitModels() (err error) { return fmt.Errorf("更新数据库失败: %v", err) } + // 模型HASH生成 + dexs := []struct { + Name string `json:"model_name"` + Hash string `json:"sha256"` + ModelCheckpoint string `json:"title"` + ModelPath string `json:"filename"` + }{} + if err = goreq.Post(fmt.Sprintf("http://%s:%d/sdapi/v1/refresh-checkpoints", server.IP, server.Port)).Do().BindJSON(&dexs); err != nil { + return fmt.Errorf("获取服务器中的模型列表失败: %v", err) + } + for _, item := range dexs { + // 如果hash为空, 则逐一加载这些模型使其生成hash + if item.Hash == "" { + if err := goreq.Post(fmt.Sprintf("http://%s:%d/sdapi/v1/options", server.IP, server.Port)).SetJsonBody(map[string]interface{}{ + "sd_model_checkpoint": item.ModelCheckpoint, + "CLIP_stop_at_last_layers": 2, + }).Do().Error(); err != nil { + return fmt.Errorf("加载模型失败: %v", err) + } + } + } + // 获取服务器中的模型列表(ckp) resp, err := http.Get(fmt.Sprintf("http://%s:%d/sdapi/v1/sd-models", server.IP, server.Port)) if err != nil { @@ -150,6 +181,9 @@ func (server *Server) InitModels() (err error) { return err } + // 如果hash为空, 则逐一加载这些模型使其生成hash + // 然后再次更新服务器中的模型列表 + // 从数据库检查此模型hash是否存在 for _, item := range data { // 打印为格式化的JSON diff --git a/routers/models.go b/routers/models.go index 03b57a4..280bd1c 100644 --- a/routers/models.go +++ b/routers/models.go @@ -21,7 +21,22 @@ func init() { models_update() } +// 检查服务器中的模型列表 +func server_models_update() { + var servers []models.Server + configs.ORMDB().Find(&servers) + fmt.Println("开始检查服务器中的模型列表") + for _, server := range servers { + fmt.Println("检查服务器中的模型列表:", server.Name) + server.InitModels() + } + fmt.Println("检查服务器中的模型列表完成") +} + +// 检查本地模型列表 func models_update() { + server_models_update() + // 初始化模型路由: 检查本地模型目录是否存在, 不存在则创建 if _, err := os.Stat("data/models"); err != nil { if err := os.MkdirAll("data/models", 0777); err != nil {