408 lines
15 KiB
Go
408 lines
15 KiB
Go
package models
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/md5"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"log"
|
|
"main/configs"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"time"
|
|
|
|
"encoding/base64"
|
|
"image/png"
|
|
|
|
"github.com/chai2010/webp"
|
|
"github.com/zhshch2002/goreq"
|
|
)
|
|
|
|
type Model struct {
|
|
ID int `json:"id" gorm:"primary_key"` // 模型ID
|
|
Name string `json:"name"` // 模型名稱
|
|
Info string `json:"info"` // 模型描述
|
|
Type string `json:"type"` // 模型類型(lora|ckp|hyper|ti)
|
|
TriggerWords string `json:"trigger_words"` // 觸發詞
|
|
BaseModel string `json:"base_model"` // 基礎模型(SD1.5|SD2)
|
|
ModelPath string `json:"model_path"` // 模型路徑(實際存放在服務器上的路徑)
|
|
Status string `json:"status" default:"initial"` // (initial|ready|waiting|running|success|error)
|
|
Progress int `json:"progress"` // (0-100)
|
|
Image string `json:"image"` // 封面圖片實際地址
|
|
Hash string `json:"hash"` // 模型哈希值
|
|
Epochs int `json:"epochs"` // 訓練步數
|
|
LearningRate float32 `json:"learning_rate"` // 學習率(0.000005)
|
|
Tags TagList `json:"tags"` // 模型標籤(標籤名數組)
|
|
UserID int `json:"user_id"` // 模型的所有者
|
|
DatasetID int `json:"dataset_id"` // 模型所使用的數據集ID
|
|
ServerID string `json:"server_id"` // 模型所在服務器(訓練機或推理機)
|
|
Stars StarList `json:"stars"` // 模型的收藏者
|
|
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
|
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
|
}
|
|
|
|
// 创建一个带缓冲的通道,缓冲区大小为 10
|
|
// var ch = make(chan int, 10)
|
|
|
|
func init() {
|
|
configs.ORMDB().AutoMigrate(&Model{})
|
|
|
|
// 检查 images 目录是否存在, 不存在则创建
|
|
if _, err := os.Stat("data/images"); err != nil {
|
|
if err := os.MkdirAll("data/images", 0777); err != nil {
|
|
log.Println(err)
|
|
}
|
|
}
|
|
|
|
// 处理推理任务
|
|
//go func() {
|
|
// for {
|
|
// // 从通道中取出一个数据
|
|
// model := <-ch
|
|
// // 模型状态变化时, 向监听此模型的所有连接发送消息
|
|
// }
|
|
//}()
|
|
}
|
|
|
|
func (model *Model) Inference(image_list []Image, callback func(Image)) {
|
|
// 模型未部署到推理機
|
|
if model.ServerID == "" {
|
|
//log.Println("模型未部署到推理機, 开始部署模型")
|
|
log.Println("模型已部署到推理機, 开始推理模型")
|
|
|
|
var server Server
|
|
server.IP = "106.15.192.42"
|
|
server.Port = 7860
|
|
//if err := configs.ORMDB().Where("models LIKE ?", "%"+model.Name+"%").Take(&server).Error; err != nil {
|
|
// log.Println(err)
|
|
// // 如果没有则寻找空闲服务器
|
|
// // 如果没有空闲则创建新服务器
|
|
// // 取一台空闲的推理机上传并切换到此模型
|
|
// // 新建一台推理机上传并切换到此模型
|
|
//}
|
|
|
|
// 执行生成任务
|
|
if model.Image == "" {
|
|
// 发送的参数
|
|
var data = struct {
|
|
//EnableHr bool `json:"enable_hr"`
|
|
//DenoisingStrength int `json:"denoising_strength"`
|
|
//FirstphaseWidth int `json:"firstphase_width"`
|
|
//FirstphaseHeight int `json:"firstphase_height"`
|
|
//HrScale int `json:"hr_scale"`
|
|
//HrUpscaler string `json:"hr_upscaler"`
|
|
//HrSecondPassSteps int `json:"hr_second_pass_steps"`
|
|
//HrResizeX int `json:"hr_resize_x"`
|
|
//HrResizeY int `json:"hr_resize_y"`
|
|
//HrSamplerName string `json:"hr_sampler_name"`
|
|
//HrPrompt string `json:"hr_prompt"`
|
|
//HrNegativePrompt string `json:"hr_negative_prompt"`
|
|
Prompt string `json:"prompt"`
|
|
//Styles []string `json:"styles"`
|
|
Seed int `json:"seed"`
|
|
//Subseed int `json:"subseed"`
|
|
//SubseedStrength int `json:"subseed_strength"`
|
|
//SeedResizeFromH int `json:"seed_resize_from_h"`
|
|
//SeedResizeFromW int `json:"seed_resize_from_w"`
|
|
SamplerName string `json:"sampler_name"`
|
|
//BatchSize int `json:"batch_size"`
|
|
NIter int `json:"n_iter"`
|
|
Steps int `json:"steps"`
|
|
CfgScale int `json:"cfg_scale"`
|
|
//Width int `json:"width"`
|
|
//Height int `json:"height"`
|
|
//RestoreFaces bool `json:"restore_faces"`
|
|
//Tiling bool `json:"tiling"`
|
|
//DoNotSaveSamples bool `json:"do_not_save_samples"`
|
|
//DoNotSaveGrid bool `json:"do_not_save_grid"`
|
|
//NegativePrompt string `json:"negative_prompt"`
|
|
//Eta int `json:"eta"`
|
|
//SMinUncond int `json:"s_min_uncond"`
|
|
//SChurn int `json:"s_churn"`
|
|
//STmax int `json:"s_tmax"`
|
|
//STmin int `json:"s_tmin"`
|
|
//SNoise int `json:"s_noise"`
|
|
//OverrideSettings map[string]string `json:"override_settings"`
|
|
//OverrideSettingsRestoreAfterwards bool `json:"override_settings_restore_afterwards"`
|
|
//ScriptArgs []interface{} `json:"script_args"`
|
|
//SamplerIndex string `json:"sampler_index"`
|
|
//ScriptName string `json:"script_name"`
|
|
//SendImages bool `json:"send_images"`
|
|
//SaveImages bool `json:"save_images"`
|
|
//AlwaysonScripts map[string]string `json:"alwayson_scripts"`
|
|
}{
|
|
//EnableHr: false,
|
|
//DenoisingStrength: 0,
|
|
//FirstphaseWidth: 0,
|
|
//FirstphaseHeight: 0,
|
|
//HrScale: 2,
|
|
//HrUpscaler: "nearest",
|
|
//HrSecondPassSteps: 0,
|
|
//HrResizeX: 0,
|
|
//HrResizeY: 0,
|
|
//HrSamplerName: "",
|
|
//HrPrompt: "",
|
|
//HrNegativePrompt: "",
|
|
Prompt: image_list[0].Prompt,
|
|
//Styles: []string{},
|
|
Seed: image_list[0].Seed,
|
|
//Subseed: -1,
|
|
//SubseedStrength: 0,
|
|
//SeedResizeFromH: -1,
|
|
//SeedResizeFromW: -1,
|
|
SamplerName: image_list[0].SamplerName, // 采样器名称
|
|
//BatchSize: 1,
|
|
NIter: len(image_list), // 1~100
|
|
Steps: 50, // 1~150
|
|
CfgScale: image_list[0].CfgScale,
|
|
//Width: 512,
|
|
//Height: 512,
|
|
//RestoreFaces: false,
|
|
//Tiling: false,
|
|
//DoNotSaveSamples: false,
|
|
//DoNotSaveGrid: false,
|
|
//NegativePrompt: "",
|
|
//Eta: 0,
|
|
//SMinUncond: 0,
|
|
//SChurn: 0,
|
|
//STmax: 0,
|
|
//STmin: 0,
|
|
//SNoise: 1,
|
|
//OverrideSettings: map[string]string{},
|
|
//OverrideSettingsRestoreAfterwards: false,
|
|
//ScriptArgs: []interface{}{},
|
|
//SamplerIndex: "Euler",
|
|
//ScriptName: "generate",
|
|
//SendImages: true,
|
|
//SaveImages: false,
|
|
//AlwaysonScripts: map[string]string{},
|
|
}
|
|
// 接收到的图片列表
|
|
var rest = struct {
|
|
Images []string `json:"images"`
|
|
}{}
|
|
var url = fmt.Sprintf("http://%s:%d/sdapi/v1/txt2img", server.IP, server.Port)
|
|
if err := goreq.Post(url).SetJsonBody(data).Do().BindJSON(&rest); err != nil {
|
|
log.Println("API 查询失败:", err)
|
|
}
|
|
for index, img := range rest.Images {
|
|
var filename = fmt.Sprintf("%x", md5.Sum([]byte(img)))
|
|
log.Println("保存图片:", filename)
|
|
if err := SaveBase64Image(img, "data/images/"+filename+".webp"); err != nil {
|
|
log.Println(err)
|
|
}
|
|
image_list[index].Name = filename
|
|
image_list[index].Path = "data/images/" + filename + ".webp"
|
|
image_list[index].Hash = filename
|
|
image_list[index].Type = "image/webp"
|
|
image_list[index].Width = 512
|
|
image_list[index].Height = 512
|
|
image_list[index].Format = "webp"
|
|
image_list[index].Status = "success"
|
|
image_list[index].Progress = 100
|
|
callback(image_list[index])
|
|
}
|
|
}
|
|
return
|
|
}
|
|
log.Println("模型未部署到推理機, 取消推理模型")
|
|
}
|
|
|
|
// 将base64编码的图片保存到本地webp
|
|
func SaveBase64Image(base64Str string, filename string) error {
|
|
// 解码base64图片
|
|
data, err := base64.StdEncoding.DecodeString(base64Str)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 将png图片解码为image.Image
|
|
img, err := png.Decode(bytes.NewReader(data))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 创建webp编码器
|
|
webpWriter, err := os.Create(filename)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer webpWriter.Close()
|
|
|
|
// 将image.Image编码为webp格式并保存到本地
|
|
if err := webp.Encode(webpWriter, img, &webp.Options{Lossless: true}); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (model *Model) Train() (err error) {
|
|
|
|
// 獲取一臺空閒的訓練機
|
|
var server Server
|
|
if err = configs.ORMDB().Where("status = ?", "正常").First(&server).Error; err != nil {
|
|
fmt.Println(err)
|
|
// TOOD: 沒有空閒的訓練機, 訓練排隊, 等待訓練機空閒
|
|
// TODO: 如果訓練機數量低於10臺, 則創建新的訓練機
|
|
return
|
|
}
|
|
|
|
// 獲取數據集
|
|
var dataset Dataset = Dataset{ID: model.DatasetID}
|
|
if err = configs.ORMDB().First(&dataset).Error; err != nil {
|
|
fmt.Println(err)
|
|
return
|
|
}
|
|
|
|
// 更新模型狀態
|
|
model.ServerID = server.ID
|
|
model.Status = "training"
|
|
if err = configs.ORMDB().Save(&model).Error; err != nil {
|
|
fmt.Println(err)
|
|
return
|
|
}
|
|
|
|
// 創建數據集目錄
|
|
dirPath := filepath.Join("data/datasets", fmt.Sprint(dataset.ID), "images")
|
|
if err := os.MkdirAll(dirPath, 0755); err != nil {
|
|
fmt.Println(err)
|
|
return err
|
|
}
|
|
|
|
// 將數據下載到本地
|
|
for index, url := range dataset.Images {
|
|
fmt.Println("下載數據到本地:", index, url)
|
|
|
|
// 檢查文件是否已經存在
|
|
filename := fmt.Sprintf("%x", md5.Sum([]byte(url)))
|
|
filePath := filepath.Join(dirPath, filename)
|
|
if _, err := os.Stat(filePath); err == nil {
|
|
fmt.Println("文件已經存在:", filePath)
|
|
continue
|
|
}
|
|
|
|
// 下載到臨時目錄
|
|
resp, err := http.Get(url)
|
|
if err != nil {
|
|
fmt.Println("下載失敗:", err)
|
|
continue
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
data, err := ioutil.ReadAll(resp.Body)
|
|
if err != nil {
|
|
fmt.Println("保存失敗:", err)
|
|
continue
|
|
}
|
|
|
|
// 保存文件到本地目錄下
|
|
if err := ioutil.WriteFile(filePath, data, 0644); err != nil {
|
|
fmt.Println(err)
|
|
continue
|
|
}
|
|
|
|
}
|
|
|
|
fmt.Println("數據下載完成")
|
|
|
|
// 檢查目錄下是否有文件, 如果沒有文件則返回錯誤
|
|
files, err := ioutil.ReadDir(dirPath)
|
|
if err != nil || len(files) == 0 {
|
|
fmt.Println("目錄下沒有文件")
|
|
return fmt.Errorf("目錄下沒有文件")
|
|
}
|
|
|
|
// 將數據上傳到訓練機
|
|
|
|
// 按類型執行訓練任務
|
|
if model.Type == "dreambooth" {
|
|
// 創建數據庫模型
|
|
fmt.Println("創建數據庫模型 ======================================")
|
|
resp, err := http.PostForm(fmt.Sprintf("http://%s:%d/dreambooth/createModel?new_model_name=%s&new_model_src=%s", server.IP, server.Port, model.Name, model.ModelPath), nil)
|
|
if err != nil {
|
|
fmt.Println("創建訓練任務失敗:", err.Error())
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// 打印返回的結果
|
|
body, err := ioutil.ReadAll(resp.Body)
|
|
if err != nil {
|
|
fmt.Println("解碼任務數據失敗:", err)
|
|
return err
|
|
}
|
|
fmt.Println("預覽:", string(body))
|
|
|
|
// 上傳數據到訓練機
|
|
|
|
// 執行訓練命令
|
|
}
|
|
|
|
if model.Type == "lora" {
|
|
// 創建數據庫模型
|
|
formData := url.Values{}
|
|
formData.Set("name", model.Name)
|
|
formData.Set("type", model.Type)
|
|
resp, err := http.PostForm(fmt.Sprintf("http://%s:%d/lora/createModel", server.IP, server.Port), formData)
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// 上傳數據到訓練機
|
|
|
|
// 執行訓練命令
|
|
}
|
|
|
|
//// 將文件全部上傳到訓練機, 使用scp命令,自動使用密碼登錄
|
|
//err = exec.Command("scp", "-r", dirPath, fmt.Sprintf("%s@%s:~/dataset_%d", server.UserName, server.IP, model.ID)).Run()
|
|
//if err != nil {
|
|
// fmt.Println(err)
|
|
// return err
|
|
//}
|
|
//// 刪除本地臨時目錄
|
|
//if err := os.RemoveAll(dirPath); err != nil {
|
|
// fmt.Println(err)
|
|
// return err
|
|
//}
|
|
//// 将基础模型上传到训练机(使用scp命令)
|
|
//baseModelPath := filepath.Join("data/models", model.BaseModel)
|
|
//fmt.Println("baseModelPath:", baseModelPath)
|
|
//err = exec.Command("scp", baseModelPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run()
|
|
//if err != nil {
|
|
// fmt.Println(err)
|
|
// return err
|
|
//}
|
|
//// 進行訓練(訓練機上調用訓練webapi接口:參數)
|
|
//resp, err := http.Post(fmt.Sprintf("http://%s:5000/train", server.IP), "application/json", nil)
|
|
//if err != nil {
|
|
// fmt.Println(err)
|
|
// return err
|
|
//}
|
|
//defer resp.Body.Close()
|
|
//// 循環監聽訓練進度
|
|
//for i := 0; i < 5; i++ {
|
|
// // 訓練機上調用訓練webapi接口:獲取訓練進度
|
|
// resp, err := http.Get(fmt.Sprintf("http://%s:5000/progress", server.IP))
|
|
// if err != nil {
|
|
// fmt.Println(err)
|
|
// return err
|
|
// }
|
|
// defer resp.Body.Close()
|
|
//// 更新本地訓練進度
|
|
// var progress int
|
|
// if err := json.NewDecoder(resp.Body).Decode(&progress); err != nil {
|
|
// fmt.Println(err)
|
|
// return err
|
|
// }
|
|
//}
|
|
//
|
|
// TODO: 訓練完成後將模型下載到本地
|
|
return nil
|
|
|
|
}
|