diff --git a/routers/models.go b/routers/models.go index c1aea6c..6dfe0a1 100644 --- a/routers/models.go +++ b/routers/models.go @@ -1,14 +1,17 @@ package routers import ( + "crypto/sha256" "encoding/json" "fmt" + "io" "io/ioutil" "log" "main/configs" "main/models" "main/utils" "net/http" + "os" "strconv" "github.com/gorilla/mux" @@ -17,6 +20,54 @@ import ( var manager = models.NewWebSocketManager() +func init() { + // 初始化模型路由: 检查本地模型目录是否存在, 不存在则创建 + if _, err := os.Stat("data/models"); err != nil { + if err := os.MkdirAll("data/models", 0777); err != nil { + log.Println(err) + } + } + // 检查模型目录中是否存在模型文件, 如果存在且数据库中未记录, 则将模型信息写入数据库 + if files, err := ioutil.ReadDir("data/models"); err == nil { + for _, file := range files { + if file.IsDir() { + continue + } + + log.Println("检查模型是否存在:", file.Name()) + + // 检查文件是否已经存在 + var model models.Model + if err := configs.ORMDB().Take(&model, "name = ?", file.Name()).Error; err == nil { + continue + } + + // 计算文件的 sha256 值 + f, err := os.Open("data/models/" + file.Name()) + if err != nil { + log.Println(err) + continue + } + defer f.Close() + + hash := sha256.New() + if _, err := io.Copy(hash, f); err != nil { + log.Println(err) + continue + } + + model.Name = file.Name() + model.Hash = fmt.Sprintf("%x", hash.Sum(nil)) + model.ModelPath = "data/models/" + file.Name() + model.Type = "ckp" + model.Status = "success" + + log.Println("模型不存在, 添加到数据库:", file.Name()) + configs.ORMDB().Create(&model) + } + } +} + // 獲取模型列表 func ModelsGet(w http.ResponseWriter, r *http.Request) { var listview models.ListView