package routers import ( "crypto/sha256" "encoding/json" "fmt" "io" "io/ioutil" "log" "main/configs" "main/models" "main/utils" "net/http" "os" "strconv" "github.com/gorilla/mux" ) func init() { models_update() } func models_update() { // 初始化模型路由: 检查本地模型目录是否存在, 不存在则创建 if _, err := os.Stat("data/models"); err != nil { if err := os.MkdirAll("data/models", 0777); err != nil { log.Println(err) } } // 检查模型目录中是否存在模型文件, 如果存在且数据库中未记录, 则将模型信息写入数据库 if files, err := ioutil.ReadDir("data/models"); err == nil { for _, file := range files { if file.IsDir() { continue } log.Println("检查模型是否存在:", file.Name()) // 检查文件是否已经存在 var model models.Model if err := configs.ORMDB().Take(&model, "name = ?", file.Name()).Error; err == nil { continue } // 计算文件的 sha256 值 f, err := os.Open("data/models/" + file.Name()) if err != nil { log.Println(err) continue } defer f.Close() hash := sha256.New() if _, err := io.Copy(hash, f); err != nil { log.Println(err) continue } model.Name = file.Name() model.Hash = fmt.Sprintf("%x", hash.Sum(nil)) model.ModelPath = "data/models/" + file.Name() model.Type = "ckp" model.Status = "success" model.Progress = 100 model.Tags = []string{"平台模型"} log.Println("模型不存在, 添加到数据库:", file.Name()) configs.ORMDB().Create(&model) } } } // 更新检查本地模型列表 func ModelsUpdate(w http.ResponseWriter, r *http.Request) { models_update() w.Write([]byte("ok")) } // 獲取模型列表 func ModelsGet(w http.ResponseWriter, r *http.Request) { var listview models.ListView listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1) listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10) var model_list []models.Model db := configs.ORMDB() // 按照 user_id 篩選 if user_id := utils.ParamInt(r.URL.Query().Get("user_id"), 0); user_id > 0 { db = db.Where("user_id = ?", user_id) } // 按照 star 篩選 if star := utils.ParamInt(r.URL.Query().Get("star"), 0); star > 0 { db = db.Where("stars LIKE ?", "%"+strconv.Itoa(star)+"%") } // 按照 name 模糊搜索 if name := r.URL.Query().Get("name"); name != "" { db = db.Where("name LIKE ?", "%"+name+"%") } // 按照 type 篩選 if model_type := r.URL.Query().Get("type"); model_type != "" { db = db.Where("type = ?", model_type) } // 按照 status 篩選 if status := r.URL.Query().Get("status"); status != "" { db = db.Where("status = ?", status) } // 按照 tag 篩選 if tag := r.URL.Query().Get("tag"); tag != "" { db = db.Where("tags LIKE ?", "%"+tag+"%") } db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Find(&model_list) for _, model := range model_list { listview.List = append(listview.List, model) } db.Model(&models.Model{}).Count(&listview.Total) listview.Next = listview.Page*listview.PageSize < int(listview.Total) listview.WriteJSON(w) } // 創建模型(訓練新模型) func ModelsPost(w http.ResponseWriter, r *http.Request) { models.AccountRead(w, r, func(account *models.Account) { fmt.Println(account) // 創建模型 var model models.Model body, err := ioutil.ReadAll(r.Body) if err != nil { w.WriteHeader(http.StatusBadRequest) w.Write([]byte(err.Error())) return } defer r.Body.Close() if err = json.Unmarshal(body, &model); err != nil { w.WriteHeader(http.StatusBadRequest) w.Write([]byte(err.Error())) return } if model.Name == "" { model.Name = utils.RandomString(8) } if model.Type == "" { w.WriteHeader(http.StatusBadRequest) w.Write([]byte("模型類型不能為空(recommend|lora|ckp|hyper|ti)")) return } if model.TriggerWords == "" { w.WriteHeader(http.StatusBadRequest) w.Write([]byte("觸發詞不能為空")) return } if model.BaseModel == "" { w.WriteHeader(http.StatusBadRequest) w.Write([]byte("基礎模型不能為空(SD1.5|SD2)")) return } if model.Epochs <= 0 { w.WriteHeader(http.StatusBadRequest) w.Write([]byte("訓練輪數不能小於0")) return } if model.Tags == nil { model.Tags = []string{} } model.UserID = account.ID model.Status = "initial" if err := configs.ORMDB().Create(&model).Error; err != nil { log.Println(err) return } // 直接提交訓練任務 // go model.Train() // 返回創建的模型 w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(model)) }) } // 獲取模型詳情 func ModelItemGet(w http.ResponseWriter, r *http.Request) { var model = models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} if err := configs.ORMDB().Take(&model, utils.ParamInt(mux.Vars(r)["id"], 0)).Error; err != nil { w.WriteHeader(http.StatusNotFound) w.Write([]byte(err.Error())) return } w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(model)) } // 更新模型 func ModelItemPatch(w http.ResponseWriter, r *http.Request) { var model = models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} if err := configs.ORMDB().Take(&model, utils.ParamInt(mux.Vars(r)["id"], 0)).Error; err != nil { w.WriteHeader(http.StatusNotFound) w.Write([]byte(err.Error())) return } // 取出更新数据 var model_new models.Model body, err := ioutil.ReadAll(r.Body) if err != nil { log.Println(err) return } defer r.Body.Close() if err = json.Unmarshal(body, &model_new); err != nil { log.Println(err) return } // 字段不爲空且不等於原始數據時更新 if model_new.Name != "" && model_new.Name != model.Name { model.Name = model_new.Name } if model_new.Type != "" && model_new.Type != model.Type { model.Type = model_new.Type } if model_new.Status != "" && model_new.Status != model.Status { model.Status = model_new.Status // 如果狀態被改變爲 ready, 將模型發送到訓練隊列 if model.Status == "ready" { model.Status = "training" //go model.Train() } } if model_new.Image != "" && model_new.Image != model.Image { model.Image = model_new.Image } if model_new.TriggerWords != "" && model_new.TriggerWords != model.TriggerWords { model.TriggerWords = model_new.TriggerWords } if model_new.BaseModel != "" && model_new.BaseModel != model.BaseModel { model.BaseModel = model_new.BaseModel } if model_new.ModelPath != "" && model_new.ModelPath != model.ModelPath { model.ModelPath = model_new.ModelPath } if model_new.Hash != "" && model_new.Hash != model.Hash { model.Hash = model_new.Hash } if model_new.Epochs != 0 && model_new.Epochs != model.Epochs { model.Epochs = model_new.Epochs } if model_new.Progress != 0 && model_new.Progress != model.Progress { model.Progress = model_new.Progress } if model_new.Tags != nil && len(model_new.Tags) != len(model.Tags) { model.Tags = model_new.Tags } // 執行更新 if err := configs.ORMDB().Save(&model).Error; err != nil { log.Println(err) return } // 返回更新後的數據 w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(model)) } // 刪除模型 func ModelItemDelete(w http.ResponseWriter, r *http.Request) { var model = models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} if err := configs.ORMDB().Take(&model, utils.ParamInt(mux.Vars(r)["id"], 0)); err != nil { w.WriteHeader(http.StatusNotFound) return } w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(model)) }