Files
ai/routers/models.go
2023-08-13 06:28:30 +08:00

429 lines
11 KiB
Go

package routers
import (
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"log"
"main/configs"
"main/models"
"main/utils"
"net/http"
"os"
"regexp"
"strconv"
"github.com/gorilla/mux"
)
// init synchronises the model catalogue once, when the package is loaded, by
// running models_update. This performs database and filesystem I/O at startup.
func init() {
	models_update()
}
// server_models_update loads every models.Server row from the database and
// calls InitModels on each one so that server's model list is refreshed.
// Progress is printed to stdout. If the server list cannot be loaded the error
// is logged and no server is processed.
func server_models_update() {
	var servers []models.Server
	if err := configs.ORMDB().Find(&servers).Error; err != nil {
		log.Println("loading servers:", err)
		return
	}
	fmt.Println("开始检查服务器中的模型列表")
	for _, server := range servers {
		fmt.Println("检查服务器中的模型列表:", server.Name)
		server.InitModels()
	}
	fmt.Println("检查服务器中的模型列表完成")
}
// models_update synchronises the local model catalogue.
//
// It first refreshes the model lists of all configured servers, then makes
// sure the data/models directory exists, and finally registers every regular
// file in that directory for which no models.Model row with the same name can
// be loaded. New rows get type "ckp", status "success", progress 100, the tag
// "平台模型", and the file's hex SHA-256 digest as Hash.
//
// Errors are logged and the offending file is skipped; the function never
// fails as a whole. Note that any lookup error (not only "record not found")
// makes a file count as unregistered.
func models_update() {
	server_models_update()

	const modelDir = "data/models"

	// Create the model directory if it does not exist yet.
	if _, err := os.Stat(modelDir); err != nil {
		if err := os.MkdirAll(modelDir, 0777); err != nil {
			log.Println(err)
		}
	}

	files, err := ioutil.ReadDir(modelDir)
	if err != nil {
		log.Println(err)
		return
	}
	for _, file := range files {
		if file.IsDir() {
			continue
		}
		log.Println("检查模型是否存在:", file.Name())

		// Skip files that are already registered.
		var model models.Model
		if err := configs.ORMDB().Take(&model, "name = ?", file.Name()).Error; err == nil {
			continue
		}

		path := modelDir + "/" + file.Name()
		sum, err := hashModelFile(path)
		if err != nil {
			log.Println(err)
			continue
		}

		model.Name = file.Name()
		model.Hash = sum
		model.ModelPath = path
		model.Type = "ckp"
		model.Status = "success"
		model.Progress = 100
		model.Tags = []string{"平台模型"}
		log.Println("模型不存在, 添加到数据库:", file.Name())
		if err := configs.ORMDB().Create(&model).Error; err != nil {
			log.Println(err)
		}
	}
}

// hashModelFile returns the lowercase hex SHA-256 digest of the file at path.
// The file is closed before returning, so calling it in a loop does not
// accumulate open descriptors.
func hashModelFile(path string) (string, error) {
	f, err := os.Open(path)
	if err != nil {
		return "", err
	}
	defer f.Close()

	hash := sha256.New()
	if _, err := io.Copy(hash, f); err != nil {
		return "", err
	}
	return fmt.Sprintf("%x", hash.Sum(nil)), nil
}
// ModelsUpdate is an HTTP handler that re-runs the local model catalogue
// synchronisation (models_update) and replies with the plain body "ok".
func ModelsUpdate(w http.ResponseWriter, r *http.Request) {
	models_update()
	reply := []byte("ok")
	w.Write(reply)
}
// ModelsGet is an HTTP handler that writes one page of models as a
// models.ListView JSON document.
//
// Query parameters (all optional):
//   - page, pageSize: 1-based page number (default 1) and page size
//     (default 10); values below 1 fall back to those defaults.
//   - user_id: only models with this user_id.
//   - star: models whose stars column contains the number as a substring.
//   - name: substring match on the model name.
//   - type, status: exact match.
//   - tag: substring match on the tags column.
//   - like: only the model ids returned by models.LikeModel.GetA(like).
//
// Total counts all matching rows (not just the returned page), which is what
// Next is derived from. Database failures are logged and answered with 500.
func ModelsGet(w http.ResponseWriter, r *http.Request) {
	query := r.URL.Query()
	fail := func(err error) {
		log.Println(err)
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(err.Error()))
	}

	var listview models.ListView
	listview.Page = utils.ParamInt(query.Get("page"), 1)
	listview.PageSize = utils.ParamInt(query.Get("pageSize"), 10)
	// page=0 or a non-positive size would produce a negative offset / empty limit.
	if listview.Page < 1 {
		listview.Page = 1
	}
	if listview.PageSize < 1 {
		listview.PageSize = 10
	}

	db := configs.ORMDB()
	if userID := utils.ParamInt(query.Get("user_id"), 0); userID > 0 {
		db = db.Where("user_id = ?", userID)
	}
	// NOTE(review): substring match on the stars column, so star=1 also
	// matches rows containing 10, 21, ... — confirm this is intended.
	if star := utils.ParamInt(query.Get("star"), 0); star > 0 {
		db = db.Where("stars LIKE ?", "%"+strconv.Itoa(star)+"%")
	}
	if name := query.Get("name"); name != "" {
		db = db.Where("name LIKE ?", "%"+name+"%")
	}
	if modelType := query.Get("type"); modelType != "" {
		db = db.Where("type = ?", modelType)
	}
	if status := query.Get("status"); status != "" {
		db = db.Where("status = ?", status)
	}
	if tag := query.Get("tag"); tag != "" {
		db = db.Where("tags LIKE ?", "%"+tag+"%")
	}
	// Restrict to the models liked by the given user.
	if like := query.Get("like"); like != "" {
		list, err := models.LikeModel.GetA(like)
		if err != nil {
			fail(err)
			return
		}
		db = db.Where("id IN (?)", list)
	}

	// Count before applying Offset/Limit: counting the paginated statement
	// would only see the current page (or nothing once offset > 0).
	if err := db.Model(&models.Model{}).Count(&listview.Total).Error; err != nil {
		fail(err)
		return
	}

	var modelList []models.Model
	offset := (listview.Page - 1) * listview.PageSize
	if err := db.Offset(offset).Limit(listview.PageSize).Preload("User").Find(&modelList).Error; err != nil {
		fail(err)
		return
	}

	listview.List = modelList
	listview.Next = listview.Page*listview.PageSize < int(listview.Total)
	listview.WriteJSON(w)
}
// ModelsPost is an HTTP handler that creates a new model (training job
// definition) owned by the authenticated account.
//
// The request body is a JSON models.Model. Name defaults to a random
// 8-character string; Type, TriggerWords and BaseModel are required and
// Epochs must be positive. The stored model gets the caller's account ID as
// UserID and status "initial", and is echoed back as JSON. Validation and
// body errors are answered with 400; database failures with 500.
func ModelsPost(w http.ResponseWriter, r *http.Request) {
	models.AccountRead(w, r, func(account *models.Account) {
		badRequest := func(msg string) {
			w.WriteHeader(http.StatusBadRequest)
			w.Write([]byte(msg))
		}

		defer r.Body.Close()
		body, err := ioutil.ReadAll(r.Body)
		if err != nil {
			badRequest(err.Error())
			return
		}
		var model models.Model
		if err := json.Unmarshal(body, &model); err != nil {
			badRequest(err.Error())
			return
		}

		// The primary key is assigned by the database; a client-supplied ID
		// could otherwise collide with or probe for existing rows.
		model.ID = 0
		if model.Name == "" {
			model.Name = utils.RandomString(8)
		}
		switch {
		case model.Type == "":
			badRequest("模型類型不能為空(recommend|lora|ckp|hyper|ti)")
			return
		case model.TriggerWords == "":
			badRequest("觸發詞不能為空")
			return
		case model.BaseModel == "":
			badRequest("基礎模型不能為空(SD1.5|SD2)")
			return
		case model.Epochs <= 0:
			badRequest("訓練輪數必須大於0")
			return
		}
		if model.Tags == nil {
			model.Tags = []string{}
		}
		model.UserID = account.ID
		model.Status = "initial"

		if err := configs.ORMDB().Create(&model).Error; err != nil {
			log.Println(err)
			w.WriteHeader(http.StatusInternalServerError)
			w.Write([]byte(err.Error()))
			return
		}
		// Training is not started here yet:
		// go model.Train()

		w.Header().Set("Content-Type", "application/json; charset=utf-8")
		w.Write(utils.ToJSON(model))
	})
}
// 獲取模型詳情
func ModelItemGet(w http.ResponseWriter, r *http.Request) {
var model = models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
if err := configs.ORMDB().Take(&model, utils.ParamInt(mux.Vars(r)["id"], 0)).Error; err != nil {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte(err.Error()))
return
}
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(model))
}
// 獲取模型預覽圖
func ModelsItemPreview(w http.ResponseWriter, r *http.Request) {
var model = models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
var filepath = fmt.Sprintf("data/models/%d/preview/%s", model.ID, mux.Vars(r)["filename"])
fmt.Println(filepath)
// 檢查文件是否存在
if _, err := os.Stat(filepath); err != nil {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte(err.Error()))
return
}
// 返回文件
http.ServeFile(w, r, filepath)
}
// Patterns used by ModelItemPatch to pick a request format and to validate
// uploaded file names. They are compiled once rather than on every request.
var (
	multipartContentType = regexp.MustCompile(`multipart/form-data`)
	// Accepts "application/json" with or without parameters such as charset.
	jsonContentType = regexp.MustCompile(`(?i)^\s*application/json\s*(;|$)`)
	// Matches "", ".", ".." and any name containing a path separator.
	previewNameRejected = regexp.MustCompile(`^\.{0,2}$|[/\\]`)
)

// ModelItemPatch is an HTTP handler that updates the model identified by the
// {id} route variable; ids that cannot be loaded get 404.
//
// The request format is chosen by Content-Type:
//   - multipart/form-data: every uploaded file is stored as
//     data/models/{id}/preview/{filename} and Preview is pointed at
//     /api/models/{id}/preview/{filename} (the last file wins).
//   - application/json: a partial models.Model; its non-zero fields overwrite
//     the stored ones.
//
// Any other Content-Type gets 415. On success the updated model is written
// back as JSON; failures are logged and answered with 400 (malformed input)
// or 500 (I/O or database errors).
//
// NOTE(review): no authentication happens in this handler, yet it lets the
// caller change UserID — confirm access control is enforced by a middleware.
func ModelItemPatch(w http.ResponseWriter, r *http.Request) {
	id := utils.ParamInt(mux.Vars(r)["id"], 0)
	model := models.Model{ID: id}
	if err := configs.ORMDB().Take(&model, id).Error; err != nil {
		w.WriteHeader(http.StatusNotFound)
		w.Write([]byte(err.Error()))
		return
	}
	contentType := r.Header.Get("Content-Type")
	log.Println("更新模型:", model.Name)
	log.Println("Content-Type:", contentType)

	var status int
	var err error
	switch {
	case multipartContentType.MatchString(contentType):
		status, err = patchModelPreviews(r, &model)
	case jsonContentType.MatchString(contentType):
		status, err = patchModelFields(r, &model)
	default:
		status, err = http.StatusUnsupportedMediaType, fmt.Errorf("unsupported Content-Type %q", contentType)
	}
	if err != nil {
		log.Println(err)
		w.WriteHeader(status)
		w.Write([]byte(err.Error()))
		return
	}

	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	w.Write(utils.ToJSON(model))
}

// patchModelPreviews stores every file of a multipart/form-data request under
// data/models/{id}/preview/ and saves model with Preview pointing at each
// stored file in turn. Up to 32 MiB of the form is kept in memory. On failure
// it returns the HTTP status to report together with a non-nil error.
func patchModelPreviews(r *http.Request, model *models.Model) (int, error) {
	if err := r.ParseMultipartForm(32 << 20); err != nil {
		return http.StatusBadRequest, err
	}
	dir := fmt.Sprintf("data/models/%d/preview", model.ID)
	if err := os.MkdirAll(dir, 0777); err != nil {
		return http.StatusInternalServerError, err
	}
	for _, headers := range r.MultipartForm.File {
		for _, header := range headers {
			name := header.Filename
			if previewNameRejected.MatchString(name) {
				return http.StatusBadRequest, fmt.Errorf("invalid file name %q", name)
			}
			src, err := header.Open()
			if err != nil {
				return http.StatusInternalServerError, err
			}
			err = copyToFile(dir+"/"+name, src)
			src.Close()
			if err != nil {
				return http.StatusInternalServerError, err
			}
			// Point the model at the newly stored preview.
			model.Preview = fmt.Sprintf("/api/models/%d/preview/%s", model.ID, name)
			if err := configs.ORMDB().Save(model).Error; err != nil {
				return http.StatusInternalServerError, err
			}
		}
	}
	return http.StatusOK, nil
}

// patchModelFields decodes a partial models.Model from the JSON request body,
// copies its non-zero fields into model and saves it. A Status change to
// "ready" is stored as "training". On failure it returns the HTTP status to
// report together with a non-nil error.
func patchModelFields(r *http.Request, model *models.Model) (int, error) {
	defer r.Body.Close()
	body, err := ioutil.ReadAll(r.Body)
	if err != nil {
		return http.StatusBadRequest, err
	}
	var update models.Model
	if err := json.Unmarshal(body, &update); err != nil {
		return http.StatusBadRequest, err
	}

	if update.Name != "" {
		model.Name = update.Name
	}
	if update.Info != "" {
		model.Info = update.Info
	}
	if update.Type != "" {
		model.Type = update.Type
	}
	if update.Status != "" && update.Status != model.Status {
		model.Status = update.Status
		// A model switched to "ready" is marked as training; the actual
		// training call is currently disabled:
		// go model.Train()
		if model.Status == "ready" {
			model.Status = "training"
		}
	}
	if update.Preview != "" {
		model.Preview = update.Preview
	}
	if update.TriggerWords != "" {
		model.TriggerWords = update.TriggerWords
	}
	if update.BaseModel != "" {
		model.BaseModel = update.BaseModel
	}
	if update.ModelPath != "" {
		model.ModelPath = update.ModelPath
	}
	if update.Hash != "" {
		model.Hash = update.Hash
	}
	if update.Epochs != 0 {
		model.Epochs = update.Epochs
	}
	if update.Progress != 0 {
		model.Progress = update.Progress
	}
	// Replace the tags whenever they are sent; comparing lengths alone would
	// drop a same-length change.
	if update.Tags != nil {
		model.Tags = update.Tags
	}
	// TODO: only administrators should be allowed to change the owner.
	if update.UserID != 0 {
		model.UserID = update.UserID
	}

	if err := configs.ORMDB().Save(model).Error; err != nil {
		return http.StatusInternalServerError, err
	}
	return http.StatusOK, nil
}

// copyToFile creates (or truncates) the file at path and copies src into it.
// The file is always closed; its Close error is returned because it can mean
// the written data did not reach the disk.
func copyToFile(path string, src io.Reader) error {
	dst, err := os.Create(path)
	if err != nil {
		return err
	}
	if _, err := io.Copy(dst, src); err != nil {
		dst.Close()
		return err
	}
	return dst.Close()
}
// ModelItemDelete is an HTTP handler that deletes the model identified by the
// {id} route variable on behalf of the authenticated account, then writes the
// deleted model as JSON.
//
// Responses: 404 when the model cannot be loaded, 403 when it belongs to a
// different account, 500 when the delete fails.
//
// NOTE(review): only the database row is removed; files under
// data/models/{id}/ are left in place, and there is no administrator override
// of the ownership check (platform models registered from disk are not owned
// by any account) — confirm both against the intended policy.
func ModelItemDelete(w http.ResponseWriter, r *http.Request) {
	models.AccountRead(w, r, func(account *models.Account) {
		id := utils.ParamInt(mux.Vars(r)["id"], 0)
		model := models.Model{ID: id}
		// The failure is reported through .Error; the returned *gorm.DB itself
		// is never nil.
		if err := configs.ORMDB().Take(&model, id).Error; err != nil {
			w.WriteHeader(http.StatusNotFound)
			w.Write([]byte(err.Error()))
			return
		}
		if model.UserID != account.ID {
			w.WriteHeader(http.StatusForbidden)
			w.Write([]byte("forbidden"))
			return
		}
		if err := configs.ORMDB().Delete(&model).Error; err != nil {
			log.Println(err)
			w.WriteHeader(http.StatusInternalServerError)
			w.Write([]byte(err.Error()))
			return
		}
		w.Header().Set("Content-Type", "application/json; charset=utf-8")
		w.Write(utils.ToJSON(model))
	})
}
// ModelsItemLike is an HTTP handler that records a like from the
// authenticated account for the model identified by the {id} route variable.
// It answers 404 with the error text when the model cannot be loaded and
// "ok" otherwise.
func ModelsItemLike(w http.ResponseWriter, r *http.Request) {
	models.AccountRead(w, r, func(account *models.Account) {
		modelID := utils.ParamInt(mux.Vars(r)["id"], 0)
		target := models.Model{ID: modelID}

		// The model must exist before a like is recorded for it.
		if res := configs.ORMDB().Take(&target, modelID); res.Error != nil {
			w.WriteHeader(http.StatusNotFound)
			w.Write([]byte(res.Error.Error()))
			return
		}

		userKey, modelKey := strconv.Itoa(account.ID), strconv.Itoa(target.ID)
		models.LikeModel.Add(userKey, modelKey)
		w.Write([]byte("ok"))
	})
}
// ModelsItemUnlike is an HTTP handler that removes the authenticated
// account's like from the model whose raw {id} route variable is given, then
// replies "ok". Whether the model exists is not checked.
func ModelsItemUnlike(w http.ResponseWriter, r *http.Request) {
	models.AccountRead(w, r, func(account *models.Account) {
		userKey := strconv.Itoa(account.ID)
		modelKey := mux.Vars(r)["id"]
		models.LikeModel.Remove(userKey, modelKey)
		w.Write([]byte("ok"))
	})
}