321 lines
8.4 KiB
Go
321 lines
8.4 KiB
Go
package routers
|
|
|
|
import (
|
|
"crypto/sha256"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"log"
|
|
"main/configs"
|
|
"main/models"
|
|
"main/utils"
|
|
"net/http"
|
|
"os"
|
|
"strconv"
|
|
|
|
"github.com/gorilla/mux"
|
|
)
|
|
|
|
func init() {
|
|
models_update()
|
|
}
|
|
|
|
func models_update() {
|
|
// 初始化模型路由: 检查本地模型目录是否存在, 不存在则创建
|
|
if _, err := os.Stat("data/models"); err != nil {
|
|
if err := os.MkdirAll("data/models", 0777); err != nil {
|
|
log.Println(err)
|
|
}
|
|
}
|
|
// 检查模型目录中是否存在模型文件, 如果存在且数据库中未记录, 则将模型信息写入数据库
|
|
if files, err := ioutil.ReadDir("data/models"); err == nil {
|
|
for _, file := range files {
|
|
if file.IsDir() {
|
|
continue
|
|
}
|
|
|
|
log.Println("检查模型是否存在:", file.Name())
|
|
|
|
// 检查文件是否已经存在
|
|
var model models.Model
|
|
if err := configs.ORMDB().Take(&model, "name = ?", file.Name()).Error; err == nil {
|
|
continue
|
|
}
|
|
|
|
// 计算文件的 sha256 值
|
|
f, err := os.Open("data/models/" + file.Name())
|
|
if err != nil {
|
|
log.Println(err)
|
|
continue
|
|
}
|
|
defer f.Close()
|
|
|
|
hash := sha256.New()
|
|
if _, err := io.Copy(hash, f); err != nil {
|
|
log.Println(err)
|
|
continue
|
|
}
|
|
|
|
model.Name = file.Name()
|
|
model.Hash = fmt.Sprintf("%x", hash.Sum(nil))
|
|
model.ModelPath = "data/models/" + file.Name()
|
|
model.Type = "ckp"
|
|
model.Status = "success"
|
|
model.Progress = 100
|
|
model.Tags = []string{"平台模型"}
|
|
|
|
log.Println("模型不存在, 添加到数据库:", file.Name())
|
|
configs.ORMDB().Create(&model)
|
|
}
|
|
}
|
|
}
|
|
|
|
// 更新检查本地模型列表
|
|
func ModelsUpdate(w http.ResponseWriter, r *http.Request) {
|
|
models_update()
|
|
w.Write([]byte("ok"))
|
|
}
|
|
|
|
// 獲取模型列表
|
|
func ModelsGet(w http.ResponseWriter, r *http.Request) {
|
|
var listview models.ListView
|
|
listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1)
|
|
listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10)
|
|
|
|
var model_list []models.Model
|
|
db := configs.ORMDB()
|
|
|
|
// 按照 user_id 篩選
|
|
if user_id := utils.ParamInt(r.URL.Query().Get("user_id"), 0); user_id > 0 {
|
|
db = db.Where("user_id = ?", user_id)
|
|
}
|
|
|
|
// 按照 star 篩選
|
|
if star := utils.ParamInt(r.URL.Query().Get("star"), 0); star > 0 {
|
|
db = db.Where("stars LIKE ?", "%"+strconv.Itoa(star)+"%")
|
|
}
|
|
|
|
// 按照 name 模糊搜索
|
|
if name := r.URL.Query().Get("name"); name != "" {
|
|
db = db.Where("name LIKE ?", "%"+name+"%")
|
|
}
|
|
|
|
// 按照 type 篩選
|
|
if model_type := r.URL.Query().Get("type"); model_type != "" {
|
|
db = db.Where("type = ?", model_type)
|
|
}
|
|
|
|
// 按照 status 篩選
|
|
if status := r.URL.Query().Get("status"); status != "" {
|
|
db = db.Where("status = ?", status)
|
|
}
|
|
|
|
// 按照 tag 篩選
|
|
if tag := r.URL.Query().Get("tag"); tag != "" {
|
|
db = db.Where("tags LIKE ?", "%"+tag+"%")
|
|
}
|
|
|
|
db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Find(&model_list)
|
|
for _, model := range model_list {
|
|
listview.List = append(listview.List, model)
|
|
}
|
|
db.Model(&models.Model{}).Count(&listview.Total)
|
|
listview.Next = listview.Page*listview.PageSize < int(listview.Total)
|
|
listview.WriteJSON(w)
|
|
}
|
|
|
|
// 創建模型(訓練新模型)
|
|
func ModelsPost(w http.ResponseWriter, r *http.Request) {
|
|
models.AccountRead(w, r, func(account *models.Account) {
|
|
fmt.Println(account)
|
|
|
|
// 創建模型
|
|
var model models.Model
|
|
body, err := ioutil.ReadAll(r.Body)
|
|
if err != nil {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
w.Write([]byte(err.Error()))
|
|
return
|
|
}
|
|
defer r.Body.Close()
|
|
|
|
if err = json.Unmarshal(body, &model); err != nil {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
w.Write([]byte(err.Error()))
|
|
return
|
|
}
|
|
|
|
if model.Name == "" {
|
|
model.Name = utils.RandomString(8)
|
|
}
|
|
|
|
if model.Type == "" {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
w.Write([]byte("模型類型不能為空(recommend|lora|ckp|hyper|ti)"))
|
|
return
|
|
}
|
|
|
|
if model.TriggerWords == "" {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
w.Write([]byte("觸發詞不能為空"))
|
|
return
|
|
}
|
|
|
|
if model.BaseModel == "" {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
w.Write([]byte("基礎模型不能為空(SD1.5|SD2)"))
|
|
return
|
|
}
|
|
|
|
if model.Epochs <= 0 {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
w.Write([]byte("訓練輪數不能小於0"))
|
|
return
|
|
}
|
|
|
|
if model.Tags == nil {
|
|
model.Tags = []string{}
|
|
}
|
|
|
|
model.UserID = account.ID
|
|
model.Status = "initial"
|
|
|
|
if err := configs.ORMDB().Create(&model).Error; err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
|
|
// 直接提交訓練任務
|
|
// go model.Train()
|
|
|
|
// 返回創建的模型
|
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
|
w.Write(utils.ToJSON(model))
|
|
})
|
|
}
|
|
|
|
// 獲取模型詳情
|
|
func ModelItemGet(w http.ResponseWriter, r *http.Request) {
|
|
var model = models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
|
|
if err := configs.ORMDB().Take(&model, utils.ParamInt(mux.Vars(r)["id"], 0)).Error; err != nil {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
w.Write([]byte(err.Error()))
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
|
w.Write(utils.ToJSON(model))
|
|
}
|
|
|
|
// 更新模型
|
|
func ModelItemPatch(w http.ResponseWriter, r *http.Request) {
|
|
var model = models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
|
|
if err := configs.ORMDB().Take(&model, utils.ParamInt(mux.Vars(r)["id"], 0)).Error; err != nil {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
w.Write([]byte(err.Error()))
|
|
return
|
|
}
|
|
|
|
// 取出更新数据
|
|
var model_new models.Model
|
|
body, err := ioutil.ReadAll(r.Body)
|
|
if err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
defer r.Body.Close()
|
|
if err = json.Unmarshal(body, &model_new); err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
|
|
// 字段不爲空且不等於原始數據時更新
|
|
if model_new.Name != "" && model_new.Name != model.Name {
|
|
model.Name = model_new.Name
|
|
}
|
|
if model_new.Info != "" && model_new.Info != model.Info {
|
|
model.Info = model_new.Info
|
|
}
|
|
if model_new.Type != "" && model_new.Type != model.Type {
|
|
model.Type = model_new.Type
|
|
}
|
|
if model_new.Status != "" && model_new.Status != model.Status {
|
|
model.Status = model_new.Status
|
|
// 如果狀態被改變爲 ready, 將模型發送到訓練隊列
|
|
if model.Status == "ready" {
|
|
model.Status = "training"
|
|
//go model.Train()
|
|
}
|
|
}
|
|
if model_new.Preview != "" && model_new.Preview != model.Preview {
|
|
model.Preview = model_new.Preview
|
|
}
|
|
if model_new.TriggerWords != "" && model_new.TriggerWords != model.TriggerWords {
|
|
model.TriggerWords = model_new.TriggerWords
|
|
}
|
|
if model_new.BaseModel != "" && model_new.BaseModel != model.BaseModel {
|
|
model.BaseModel = model_new.BaseModel
|
|
}
|
|
if model_new.ModelPath != "" && model_new.ModelPath != model.ModelPath {
|
|
model.ModelPath = model_new.ModelPath
|
|
}
|
|
if model_new.Hash != "" && model_new.Hash != model.Hash {
|
|
model.Hash = model_new.Hash
|
|
}
|
|
if model_new.Epochs != 0 && model_new.Epochs != model.Epochs {
|
|
model.Epochs = model_new.Epochs
|
|
}
|
|
if model_new.Progress != 0 && model_new.Progress != model.Progress {
|
|
model.Progress = model_new.Progress
|
|
}
|
|
if model_new.Tags != nil && len(model_new.Tags) != len(model.Tags) {
|
|
model.Tags = model_new.Tags
|
|
}
|
|
|
|
// 執行更新
|
|
if err := configs.ORMDB().Save(&model).Error; err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
|
|
// 返回更新後的數據
|
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
|
w.Write(utils.ToJSON(model))
|
|
}
|
|
|
|
// 刪除模型
|
|
func ModelItemDelete(w http.ResponseWriter, r *http.Request) {
|
|
var model = models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
|
|
if err := configs.ORMDB().Take(&model, utils.ParamInt(mux.Vars(r)["id"], 0)); err != nil {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
return
|
|
}
|
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
|
w.Write(utils.ToJSON(model))
|
|
}
|
|
|
|
// 添加一条喜欢
|
|
func ModelsItemLike(w http.ResponseWriter, r *http.Request) {
|
|
models.AccountRead(w, r, func(account *models.Account) {
|
|
// 先检查模型是否存在
|
|
var model = models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
|
|
if err := configs.ORMDB().Take(&model, utils.ParamInt(mux.Vars(r)["id"], 0)).Error; err != nil {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
w.Write([]byte(err.Error()))
|
|
return
|
|
}
|
|
// 添加喜欢
|
|
models.LikeModel.Add(strconv.Itoa(account.ID), strconv.Itoa(model.ID))
|
|
w.Write([]byte("ok"))
|
|
})
|
|
}
|
|
|
|
// 移除一条喜欢
|
|
func ModelsItemUnlike(w http.ResponseWriter, r *http.Request) {
|
|
models.AccountRead(w, r, func(account *models.Account) {
|
|
models.LikeModel.Remove(strconv.Itoa(account.ID), mux.Vars(r)["id"])
|
|
w.Write([]byte("ok"))
|
|
})
|
|
}
|