模型安全加载

This commit is contained in:
2023-07-10 08:32:51 +08:00
parent 30de6b2547
commit 2ef095c8e2
2 changed files with 34 additions and 37 deletions

16
main.go
View File

@@ -93,12 +93,18 @@ func main() {
r.HandleFunc("/img/{id}", routers.WebpGet).Methods("GET") r.HandleFunc("/img/{id}", routers.WebpGet).Methods("GET")
// 設定靜態資源 (前端) 位于dist目录下, 并且为 图片和 js/css 设置缓存时间为7天 // 設定靜態資源 (前端) 位于dist目录下
cacheTime := 7 * 24 * time.Hour cacheTime := 7 * 24 * time.Hour
r.PathPrefix("/images/").Handler(http.StripPrefix("/images/", http.FileServer(http.Dir("./data/")))).Headers("Cache-Control", fmt.Sprintf("max-age=%d", int(cacheTime.Seconds()))).Methods("GET") r.PathPrefix("/images/").Handler(http.FileServer(http.Dir("./data/"))).Headers("Cache-Control", fmt.Sprintf("max-age=%d", int(cacheTime.Seconds())))
r.PathPrefix("/static/").Handler(http.StripPrefix("/", http.FileServer(http.Dir("./dist/static/")))).Headers("Cache-Control", fmt.Sprintf("max-age=%d", int(cacheTime.Seconds()))).Methods("GET") r.PathPrefix("/").Handler(http.FileServer(http.Dir("./dist/")))
r.PathPrefix("/favicon.ico").Handler(http.StripPrefix("/", http.FileServer(http.Dir("./dist/favicon.ico")))).Headers("Cache-Control", fmt.Sprintf("max-age=%d", int(cacheTime.Seconds()))).Methods("GET")
r.PathPrefix("/").Handler(http.StripPrefix("/", http.FileServer(http.Dir("./dist/")))).Methods("GET") //// 設定靜態資源 (前端) 位于dist目录下, 并且为 图片和 js/css 设置缓存时间为7天
//cacheTime := 7 * 24 * time.Hour
//r.PathPrefix("/images/").Handler(http.StripPrefix("/images/", http.FileServer(http.Dir("./data/images/")))).Headers("Cache-Control", fmt.Sprintf("max-age=%d", int(cacheTime.Seconds()))).Methods("GET")
//r.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(http.Dir("./dist/static/")))).Headers("Cache-Control", fmt.Sprintf("max-age=%d", int(cacheTime.Seconds()))).Methods("GET")
//r.PathPrefix("/favicon.ico").Handler(http.StripPrefix("/", http.FileServer(http.Dir("./dist/favicon.ico")))).Headers("Cache-Control", fmt.Sprintf("max-age=%d", int(cacheTime.Seconds()))).Methods("GET")
//r.PathPrefix("/").Handler(http.StripPrefix("/", http.FileServer(http.Dir("./dist/")))).Methods("GET")
//r.PathPrefix("/").Handler(http.StripPrefix("/", http.FileServer(http.Dir("./dist/")))).Headers("Cache-Control", fmt.Sprintf("max-age=%d", int(cacheTime.Seconds()))).Methods("GET")
log.Println("Web Server is running on http://localhost:8080") log.Println("Web Server is running on http://localhost:8080")
if err := http.ListenAndServe(":8080", r); err != nil { if err := http.ListenAndServe(":8080", r); err != nil {

View File

@@ -94,7 +94,7 @@ func InitDefaultServer() (err error) {
// 初始化服务器中的模型列表 // 初始化服务器中的模型列表
func (server *Server) InitModels() (err error) { func (server *Server) InitModels() (err error) {
// 先让服务器更新模型列表 /sdapi/v1/reload-checkpoint // 刷新检查点 /sdapi/v1/reload-checkpoint
resp_update, err := http.Post(fmt.Sprintf("http://%s:%d/sdapi/v1/refresh-checkpoints", server.IP, server.Port), "application/json", nil) resp_update, err := http.Post(fmt.Sprintf("http://%s:%d/sdapi/v1/refresh-checkpoints", server.IP, server.Port), "application/json", nil)
if err != nil { if err != nil {
return fmt.Errorf("更新服务器中的模型列表失败: %v", err) return fmt.Errorf("更新服务器中的模型列表失败: %v", err)
@@ -117,10 +117,6 @@ func (server *Server) InitModels() (err error) {
// 从数据库检查此模型hash是否存在 // 从数据库检查此模型hash是否存在
for _, item := range data_lora { for _, item := range data_lora {
// 打印为格式化的JSON
//b, _ := json.MarshalIndent(item, "", " ")
//fmt.Println(string(b))
var model Model var model Model
if err := configs.ORMDB().Where("hash = ?", item["path"].(string)).First(&model).Error; err != nil { if err := configs.ORMDB().Where("hash = ?", item["path"].(string)).First(&model).Error; err != nil {
// 不存在则添加 // 不存在则添加
@@ -146,19 +142,27 @@ func (server *Server) InitModels() (err error) {
return fmt.Errorf("更新数据库失败: %v", err) return fmt.Errorf("更新数据库失败: %v", err)
} }
// 模型HASH生成 // 刷新检查点
if err = goreq.Post(fmt.Sprintf("http://%s:%d/sdapi/v1/refresh-checkpoints", server.IP, server.Port)).Do().Err; err != nil {
return fmt.Errorf("刷新检查点失败: %v", err)
}
// 获取服务器中的模型列表(ckp)
dexs := []struct { dexs := []struct {
Name string `json:"model_name"` Name string `json:"model_name"`
Hash string `json:"sha256"` Hash string `json:"sha256"`
ModelCheckpoint string `json:"title"` ModelCheckpoint string `json:"title"`
ModelPath string `json:"filename"` ModelPath string `json:"filename"`
}{} }{}
if err = goreq.Post(fmt.Sprintf("http://%s:%d/sdapi/v1/refresh-checkpoints", server.IP, server.Port)).Do().BindJSON(&dexs); err != nil {
// 获取服务器中的模型列表(ckp)
if err = goreq.Get(fmt.Sprintf("http://%s:%d/sdapi/v1/sd-models", server.IP, server.Port)).Do().BindJSON(&dexs); err != nil {
return fmt.Errorf("获取服务器中的模型列表失败: %v", err) return fmt.Errorf("获取服务器中的模型列表失败: %v", err)
} }
for _, item := range dexs { for _, item := range dexs {
// 如果hash为空, 则逐一加载这些模型使其生成hash // 如果hash为空, 则逐一加载这些模型使其生成hash
if item.Hash == "" { if item.Hash == "" {
fmt.Println("加载模型:", item.ModelCheckpoint)
if err := goreq.Post(fmt.Sprintf("http://%s:%d/sdapi/v1/options", server.IP, server.Port)).SetJsonBody(map[string]interface{}{ if err := goreq.Post(fmt.Sprintf("http://%s:%d/sdapi/v1/options", server.IP, server.Port)).SetJsonBody(map[string]interface{}{
"sd_model_checkpoint": item.ModelCheckpoint, "sd_model_checkpoint": item.ModelCheckpoint,
"CLIP_stop_at_last_layers": 2, "CLIP_stop_at_last_layers": 2,
@@ -168,40 +172,27 @@ func (server *Server) InitModels() (err error) {
} }
} }
// 获取服务器中的模型列表(ckp) // 重新获取服务器中的模型列表(ckp), 忽略hash为空的模型
resp, err := http.Get(fmt.Sprintf("http://%s:%d/sdapi/v1/sd-models", server.IP, server.Port)) if err = goreq.Get(fmt.Sprintf("http://%s:%d/sdapi/v1/sd-models", server.IP, server.Port)).Do().BindJSON(&dexs); err != nil {
if err != nil {
return fmt.Errorf("获取服务器中的模型列表失败: %v", err) return fmt.Errorf("获取服务器中的模型列表失败: %v", err)
} }
defer resp.Body.Close() for _, item := range dexs {
if item.Hash == "" {
// 解码JSON (数组) fmt.Println("忽略模型:", item.ModelCheckpoint)
var data []map[string]interface{} continue
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { }
return err
}
// 如果hash为空, 则逐一加载这些模型使其生成hash
// 然后再次更新服务器中的模型列表
// 从数据库检查此模型hash是否存在
for _, item := range data {
// 打印为格式化的JSON
b, _ := json.MarshalIndent(item, "", " ")
fmt.Println(string(b))
var model Model var model Model
if err := configs.ORMDB().Where("hash = ?", item["sha256"].(string)).First(&model).Error; err != nil { if err := configs.ORMDB().Where("hash = ?", item.Hash).First(&model).Error; err != nil {
// 不存在则添加 // 不存在则添加
model = Model{ model = Model{
Name: item["model_name"].(string), Name: item.Name,
Hash: item["sha256"].(string), Hash: item.Hash,
ModelCheckpoint: item["title"].(string), ModelCheckpoint: item.ModelCheckpoint,
ModelPath: item["filename"].(string), // TODO: 下载到本地 ModelPath: item.ModelPath,
ServerID: server.ID, ServerID: server.ID,
Type: "ckp", Type: "ckp",
} }
// TODO: 下载到本地
// 添加到数据库 // 添加到数据库
if err := configs.ORMDB().Create(&model).Error; err != nil { if err := configs.ORMDB().Create(&model).Error; err != nil {
return fmt.Errorf("添加模型到数据库失败: %v", err) return fmt.Errorf("添加模型到数据库失败: %v", err)