Files
ai/models/Model.go
2023-08-13 06:28:30 +08:00

506 lines
18 KiB
Go

package models
import (
"bytes"
"crypto/md5"
"fmt"
"io/ioutil"
"log"
"main/configs"
"net/http"
"net/url"
"os"
"path/filepath"
"time"
"encoding/base64"
"encoding/json"
"image/png"
"github.com/chai2010/webp"
"github.com/zhshch2002/goreq"
)
type Model struct {
ID int `json:"id" gorm:"primary_key"` // 模型ID
Name string `json:"name"` // 模型名稱
ModelCheckpoint string `json:"model_checkpoint"` // 模型檢查點
Info string `json:"info"` // 模型描述
Type string `json:"type"` // 模型類型(lora|ckp|hyper|ti)
TriggerWords string `json:"trigger_words"` // 觸發詞
BaseModel string `json:"base_model"` // 基礎模型(SD1.5|SD2)
ModelPath string `json:"model_path"` // 模型路徑(實際存放在服務器上的路徑)
Status string `json:"status" default:"initial"` // (initial|ready|waiting|running|success|error|public)
Progress int `json:"progress"` // (0-100)
Preview string `json:"preview"` // 模型預覽圖片
Hash string `json:"hash"` // 模型哈希值(sha256)
Epochs int `json:"epochs"` // 訓練步數
LearningRate float32 `json:"learning_rate"` // 學習率(0.000005)
Tags TagList `json:"tags"` // 模型標籤(標籤名數組)
UserID int `json:"user_id"` // 模型的所有者
DatasetID int `json:"dataset_id"` // 模型所使用的數據集ID
ServerID string `json:"server_id"` // 模型所在服務器(訓練機或推理機)
Stars StarList `json:"stars"` // 模型的收藏者
User *User `json:"user" gorm:"foreignKey:UserID;"` // 模型的所有者
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
}
func init() {
configs.ORMDB().AutoMigrate(&Model{})
if _, err := os.Stat("data/images"); err != nil {
if err := os.MkdirAll("data/images", 0777); err != nil {
log.Println(err)
}
}
// 清除所有hash长度小于32的模型
configs.ORMDB().Where("length(hash) < 32").Delete(&Model{})
// 清除所有type为空的模型
configs.ORMDB().Where("type = ?", "").Delete(&Model{})
}
// 从数据库加载
func (model *Model) Load() {
configs.ORMDB().First(&model)
}
// 从数据库加载指定的模型
func ModelLoad(id int) (model Model, err error) {
err = configs.ORMDB().First(&model, id).Error
return
}
// 推理模型
func (model *Model) Inference(image_list []Image, callback func(Image)) {
var server Server
// 模型未部署到推理機
if model.ServerID == "" {
log.Println("模型未部署到推理機, 开始部署模型")
// 寻找一台就绪的且模型位置仍有空余的推理机
if err := configs.ORMDB().Where("type = ?", "推理").Where("status = ?", "就绪").Where("length(models) < ?", 5).First(&server).Error; err != nil {
log.Println("创建一台新的推理机: 当前禁止创建新服务器")
return
}
// 打印为格式化的json
data, _ := json.MarshalIndent(server, "", " ")
fmt.Println(string(data))
// TODO: 上传模型到推理机
// 记录到推理机
server.Models = append(server.Models, model.ID)
configs.ORMDB().Save(&server)
// 记录到模型
model.ServerID = server.ID
configs.ORMDB().Save(&model)
} else {
server.ID = model.ServerID
configs.ORMDB().Take(&server)
}
// 检查推理机是否已经加载了模型
if server.ModelID != model.ID {
log.Println("推理机未加载模型, 开始排队加载模型")
// 通知关注此任务的用户
for _, img := range image_list {
img.Status = "loading"
img.Preview = "正在切换模型"
callback(img)
}
// 执行切换模型(推理机需要先处理完当前的任务才能加载新的模型)
if err := goreq.Post(fmt.Sprintf("http://%s:%d/sdapi/v1/options", server.IP, server.Port)).SetJsonBody(map[string]interface{}{
"sd_model_checkpoint": model.ModelCheckpoint,
"CLIP_stop_at_last_layers": 2,
}).Do().Error(); err != nil {
log.Println("切换模型失败:", err)
return
}
var form = struct {
SdCheckpointHash string `json:"sd_checkpoint_hash"`
SdModelCheckpoint string `json:"sd_model_checkpoint"`
}{}
// 超时时间 1分钟
var timeout = time.Now().Add(time.Second * 60)
for {
if err := goreq.Get(fmt.Sprintf("http://%s:%d/sdapi/v1/options", server.IP, server.Port)).Do().BindJSON(&form); err != nil {
log.Println("获取推理机配置失败:", err)
return
}
if form.SdModelCheckpoint == model.ModelCheckpoint {
log.Println("模型切换完成:", form.SdModelCheckpoint)
break
}
if time.Now().After(timeout) {
log.Println("模型切换超时:", form.SdModelCheckpoint)
// 通知关注此任务的用户
for _, img := range image_list {
img.Status = "error"
img.Preview = "模型切换超时"
callback(img)
}
return
}
time.Sleep(time.Second)
}
// 更新推理机模型ID
server.ModelID = model.ID
if err := configs.ORMDB().Save(&server).Error; err != nil {
log.Println("更新推理机模型ID失败:", err)
return
}
}
// 发送的参数
var img = image_list[0]
var datx map[string]interface{} = make(map[string]interface{})
datx["prompt"] = img.Prompt // 提示词
datx["seed"] = img.Seed // 随机数种子
datx["n_iter"] = len(image_list) // 生成图像数量
datx["steps"] = img.Steps // 迭代步数
datx["cfg_scale"] = img.CfgScale // 提示词引导系数 (CFG Scale)
datx["width"] = img.Width // 图片宽度
datx["height"] = img.Height // 图片高度
if img.SamplerName == "" {
datx["sampler_name"] = img.SamplerName // 采样器名称
}
if img.NegativePrompt != "" {
datx["negative_prompt"] = img.NegativePrompt // 负面提示词
}
fmt.Println("image_list:", datx)
// 接收到的图片列表
var rest = struct {
Images []string `json:"images"`
}{}
var url = fmt.Sprintf("http://%s:%d/sdapi/v1/txt2img", server.IP, server.Port)
if err := goreq.Post(url).SetJsonBody(datx).Do().BindJSON(&rest); err != nil {
log.Println("API 查询失败:", err)
}
for index, img := range rest.Images {
var filename = fmt.Sprintf("%x", md5.Sum([]byte(img)))
log.Println("保存图片:", filename)
if err := SaveBase64Image(img, "data/images/"+filename+".webp"); err != nil {
log.Println(err)
}
image_list[index].Name = filename
image_list[index].Path = "data/images/" + filename + ".webp"
image_list[index].Hash = filename
image_list[index].Type = "image/webp"
image_list[index].Format = "webp"
image_list[index].Status = "success"
image_list[index].Progress = 100
//image_list[index].Preview = img
callback(image_list[index])
}
log.Println("推理完成:", model.ID, model.Name)
}
// 加入推理任务
// 将base64编码的图片保存到本地webp
func SaveBase64Image(base64Str string, filename string) error {
// 解码base64图片
data, err := base64.StdEncoding.DecodeString(base64Str)
if err != nil {
return err
}
// 将png图片解码为image.Image
img, err := png.Decode(bytes.NewReader(data))
if err != nil {
return err
}
// 创建webp编码器
webpWriter, err := os.Create(filename)
if err != nil {
return err
}
defer webpWriter.Close()
// 将image.Image编码为webp格式并保存到本地
if err := webp.Encode(webpWriter, img, &webp.Options{Lossless: true}); err != nil {
return err
}
return nil
}
// 训练模型
func (model *Model) Train() (err error) {
// 獲取一臺空閒的訓練機
var server Server
if err = configs.ORMDB().Where("status = ?", "正常").First(&server).Error; err != nil {
fmt.Println(err)
// TOOD: 沒有空閒的訓練機, 訓練排隊, 等待訓練機空閒
// TODO: 如果訓練機數量低於10臺, 則創建新的訓練機
return
}
// 獲取數據集
var dataset Dataset = Dataset{ID: model.DatasetID}
if err = configs.ORMDB().First(&dataset).Error; err != nil {
fmt.Println(err)
return
}
// 更新模型狀態
model.ServerID = server.ID
model.Status = "training"
if err = configs.ORMDB().Save(&model).Error; err != nil {
fmt.Println(err)
return
}
// 創建數據集目錄
dirPath := filepath.Join("data/datasets", fmt.Sprint(dataset.ID), "images")
if err := os.MkdirAll(dirPath, 0755); err != nil {
fmt.Println(err)
return err
}
// 將數據下載到本地
for index, url := range dataset.Images {
fmt.Println("下載數據到本地:", index, url)
// 檢查文件是否已經存在
filename := fmt.Sprintf("%x", md5.Sum([]byte(url)))
filePath := filepath.Join(dirPath, filename)
if _, err := os.Stat(filePath); err == nil {
fmt.Println("文件已經存在:", filePath)
continue
}
// 下載到臨時目錄
resp, err := http.Get(url)
if err != nil {
fmt.Println("下載失敗:", err)
continue
}
defer resp.Body.Close()
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
fmt.Println("保存失敗:", err)
continue
}
// 保存文件到本地目錄下
if err := ioutil.WriteFile(filePath, data, 0644); err != nil {
fmt.Println(err)
continue
}
}
fmt.Println("數據下載完成")
// 檢查目錄下是否有文件, 如果沒有文件則返回錯誤
files, err := ioutil.ReadDir(dirPath)
if err != nil || len(files) == 0 {
fmt.Println("目錄下沒有文件")
return fmt.Errorf("目錄下沒有文件")
}
// 將數據上傳到訓練機
// 按類型執行訓練任務
if model.Type == "dreambooth" {
// 創建數據庫模型
fmt.Println("創建數據庫模型 ======================================")
resp, err := http.PostForm(fmt.Sprintf("http://%s:%d/dreambooth/createModel?new_model_name=%s&new_model_src=%s", server.IP, server.Port, model.Name, model.ModelPath), nil)
if err != nil {
fmt.Println("創建訓練任務失敗:", err.Error())
return err
}
defer resp.Body.Close()
// 打印返回的結果
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
fmt.Println("解碼任務數據失敗:", err)
return err
}
fmt.Println("預覽:", string(body))
// 上傳數據到訓練機
// 執行訓練命令
}
if model.Type == "lora" {
// 創建數據庫模型
formData := url.Values{}
formData.Set("name", model.Name)
formData.Set("type", model.Type)
resp, err := http.PostForm(fmt.Sprintf("http://%s:%d/lora/createModel", server.IP, server.Port), formData)
if err != nil {
fmt.Println(err)
return err
}
defer resp.Body.Close()
// 上傳數據到訓練機
// 執行訓練命令
}
//// 將文件全部上傳到訓練機, 使用scp命令,自動使用密碼登錄
//err = exec.Command("scp", "-r", dirPath, fmt.Sprintf("%s@%s:~/dataset_%d", server.UserName, server.IP, model.ID)).Run()
//if err != nil {
// fmt.Println(err)
// return err
//}
//// 刪除本地臨時目錄
//if err := os.RemoveAll(dirPath); err != nil {
// fmt.Println(err)
// return err
//}
//// 将基础模型上传到训练机(使用scp命令)
//baseModelPath := filepath.Join("data/models", model.BaseModel)
//fmt.Println("baseModelPath:", baseModelPath)
//err = exec.Command("scp", baseModelPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run()
//if err != nil {
// fmt.Println(err)
// return err
//}
//// 進行訓練(訓練機上調用訓練webapi接口:參數)
//resp, err := http.Post(fmt.Sprintf("http://%s:5000/train", server.IP), "application/json", nil)
//if err != nil {
// fmt.Println(err)
// return err
//}
//defer resp.Body.Close()
//// 循環監聽訓練進度
//for i := 0; i < 5; i++ {
// // 訓練機上調用訓練webapi接口:獲取訓練進度
// resp, err := http.Get(fmt.Sprintf("http://%s:5000/progress", server.IP))
// if err != nil {
// fmt.Println(err)
// return err
// }
// defer resp.Body.Close()
//// 更新本地訓練進度
// var progress int
// if err := json.NewDecoder(resp.Body).Decode(&progress); err != nil {
// fmt.Println(err)
// return err
// }
//}
//
// TODO: 訓練完成後將模型下載到本地
return nil
}
/**
var data = struct {
//EnableHr bool `json:"enable_hr"`
//DenoisingStrength int `json:"denoising_strength"`
//FirstphaseWidth int `json:"firstphase_width"`
//FirstphaseHeight int `json:"firstphase_height"`
//HrScale int `json:"hr_scale"`
//HrUpscaler string `json:"hr_upscaler"`
//HrSecondPassSteps int `json:"hr_second_pass_steps"`
//HrResizeX int `json:"hr_resize_x"`
//HrResizeY int `json:"hr_resize_y"`
//HrSamplerName string `json:"hr_sampler_name"`
//HrPrompt string `json:"hr_prompt"`
//HrNegativePrompt string `json:"hr_negative_prompt"`
Prompt string `json:"prompt"`
//Styles []string `json:"styles"`
Seed int `json:"seed"`
//Subseed int `json:"subseed"`
//SubseedStrength int `json:"subseed_strength"`
//SeedResizeFromH int `json:"seed_resize_from_h"`
//SeedResizeFromW int `json:"seed_resize_from_w"`
SamplerName string `json:"sampler_name"`
//BatchSize int `json:"batch_size"`
NIter int `json:"n_iter"`
Steps int `json:"steps"`
CfgScale int `json:"cfg_scale"`
//Width int `json:"width"`
//Height int `json:"height"`
//RestoreFaces bool `json:"restore_faces"`
//Tiling bool `json:"tiling"`
//DoNotSaveSamples bool `json:"do_not_save_samples"`
//DoNotSaveGrid bool `json:"do_not_save_grid"`
//NegativePrompt string `json:"negative_prompt"`
//Eta int `json:"eta"`
//SMinUncond int `json:"s_min_uncond"`
//SChurn int `json:"s_churn"`
//STmax int `json:"s_tmax"`
//STmin int `json:"s_tmin"`
//SNoise int `json:"s_noise"`
//OverrideSettings map[string]string `json:"override_settings"`
//OverrideSettingsRestoreAfterwards bool `json:"override_settings_restore_afterwards"`
//ScriptArgs []interface{} `json:"script_args"`
//SamplerIndex string `json:"sampler_index"`
//ScriptName string `json:"script_name"`
//SendImages bool `json:"send_images"`
//SaveImages bool `json:"save_images"`
//AlwaysonScripts map[string]string `json:"alwayson_scripts"`
}{
//EnableHr: false,
//DenoisingStrength: 0,
//FirstphaseWidth: 0,
//FirstphaseHeight: 0,
//HrScale: 2,
//HrUpscaler: "nearest",
//HrSecondPassSteps: 0,
//HrResizeX: 0,
//HrResizeY: 0,
//HrSamplerName: "",
//HrPrompt: "",
//HrNegativePrompt: "",
Prompt: image_list[0].Prompt,
//Styles: []string{},
Seed: image_list[0].Seed,
//Subseed: -1,
//SubseedStrength: 0,
//SeedResizeFromH: -1,
//SeedResizeFromW: -1,
SamplerName: image_list[0].SamplerName, // 采样器名称
//BatchSize: 1,
NIter: len(image_list), // 1~100
Steps: 50, // 1~150
CfgScale: image_list[0].CfgScale,
//Width: 512,
//Height: 512,
//RestoreFaces: false,
//Tiling: false,
//DoNotSaveSamples: false,
//DoNotSaveGrid: false,
//NegativePrompt: "",
//Eta: 0,
//SMinUncond: 0,
//SChurn: 0,
//STmax: 0,
//STmin: 0,
//SNoise: 1,
//OverrideSettings: map[string]string{},
//OverrideSettingsRestoreAfterwards: false,
//ScriptArgs: []interface{}{},
//SamplerIndex: "Euler",
//ScriptName: "generate",
//SendImages: true,
//SaveImages: false,
//AlwaysonScripts: map[string]string{},
}
fmt.Println("data:", data)
**/