生成图像(默认参数)

This commit is contained in:
2023-06-21 15:40:00 +08:00
parent 6f06c701ad
commit 2a71384fad
8 changed files with 324 additions and 86 deletions

View File

@@ -1,15 +1,23 @@
package models
import (
"bytes"
"crypto/md5"
"fmt"
"io/ioutil"
"log"
"main/configs"
"net/http"
"net/url"
"os"
"path/filepath"
"time"
"encoding/base64"
"image/png"
"github.com/chai2010/webp"
"github.com/zhshch2002/goreq"
)
type Model struct {
@@ -35,8 +43,184 @@ type Model struct {
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
}
// 创建一个带缓冲的通道,缓冲区大小为 10
// var ch = make(chan int, 10)
func init() {
configs.ORMDB().AutoMigrate(&Model{})
// 处理推理任务
//go func() {
// for {
// // 从通道中取出一个数据
// model := <-ch
// // 模型状态变化时, 向监听此模型的所有连接发送消息
// }
//}()
}
func (model *Model) Inference(image_list []Image, callback func()) {
log.Println(image_list)
callback()
// 模型未部署到推理機
if model.ServerID == "" {
log.Println("模型未部署到推理機, 开始部署模型")
var server Server
if err := configs.ORMDB().Where("models LIKE ?", "%"+model.Name+"%").Take(&server).Error; err != nil {
log.Println(err)
// 如果没有则寻找空闲服务器
// 如果没有空闲则创建新服务器
// 取一台空闲的推理机上传并切换到此模型
// 新建一台推理机上传并切换到此模型
}
// 执行生成任务
if model.Image == "" {
var data = struct {
EnableHr bool `json:"enable_hr"`
DenoisingStrength int `json:"denoising_strength"`
FirstphaseWidth int `json:"firstphase_width"`
FirstphaseHeight int `json:"firstphase_height"`
HrScale int `json:"hr_scale"`
HrUpscaler string `json:"hr_upscaler"`
HrSecondPassSteps int `json:"hr_second_pass_steps"`
HrResizeX int `json:"hr_resize_x"`
HrResizeY int `json:"hr_resize_y"`
HrSamplerName string `json:"hr_sampler_name"`
HrPrompt string `json:"hr_prompt"`
HrNegativePrompt string `json:"hr_negative_prompt"`
Prompt string `json:"prompt"`
Styles []string `json:"styles"`
Seed int `json:"seed"`
Subseed int `json:"subseed"`
SubseedStrength int `json:"subseed_strength"`
SeedResizeFromH int `json:"seed_resize_from_h"`
SeedResizeFromW int `json:"seed_resize_from_w"`
SamplerName string `json:"sampler_name"`
BatchSize int `json:"batch_size"`
NIter int `json:"n_iter"`
Steps int `json:"steps"`
CfgScale int `json:"cfg_scale"`
Width int `json:"width"`
Height int `json:"height"`
RestoreFaces bool `json:"restore_faces"`
Tiling bool `json:"tiling"`
DoNotSaveSamples bool `json:"do_not_save_samples"`
DoNotSaveGrid bool `json:"do_not_save_grid"`
NegativePrompt string `json:"negative_prompt"`
Eta int `json:"eta"`
SMinUncond int `json:"s_min_uncond"`
SChurn int `json:"s_churn"`
STmax int `json:"s_tmax"`
STmin int `json:"s_tmin"`
SNoise int `json:"s_noise"`
OverrideSettings map[string]string `json:"override_settings"`
OverrideSettingsRestoreAfterwards bool `json:"override_settings_restore_afterwards"`
ScriptArgs []interface{} `json:"script_args"`
SamplerIndex string `json:"sampler_index"`
ScriptName string `json:"script_name"`
SendImages bool `json:"send_images"`
SaveImages bool `json:"save_images"`
AlwaysonScripts map[string]string `json:"alwayson_scripts"`
}{
EnableHr: false,
DenoisingStrength: 0,
FirstphaseWidth: 0,
FirstphaseHeight: 0,
HrScale: 2,
HrUpscaler: "nearest",
HrSecondPassSteps: 0,
HrResizeX: 0,
HrResizeY: 0,
HrSamplerName: "",
HrPrompt: "",
HrNegativePrompt: "",
Prompt: "miao~",
Styles: []string{},
Seed: -1,
Subseed: -1,
SubseedStrength: 0,
SeedResizeFromH: -1,
SeedResizeFromW: -1,
SamplerName: "beamsearch",
BatchSize: 1,
NIter: 1,
Steps: 50,
CfgScale: 7,
Width: 512,
Height: 512,
RestoreFaces: false,
Tiling: false,
DoNotSaveSamples: false,
DoNotSaveGrid: false,
NegativePrompt: "",
Eta: 0,
SMinUncond: 0,
SChurn: 0,
STmax: 0,
STmin: 0,
SNoise: 1,
OverrideSettings: map[string]string{},
OverrideSettingsRestoreAfterwards: false,
ScriptArgs: []interface{}{},
SamplerIndex: "Euler",
ScriptName: "generate",
SendImages: true,
SaveImages: false,
AlwaysonScripts: map[string]string{},
}
// 接收到的图片列表
var rest = struct {
Images []string `json:"images"`
}{
Images: []string{},
}
var url = fmt.Sprintf("http://%s:%d/sdapi/v1/txt2img", server.IP, server.Port)
if err := goreq.Post(url).SetJsonBody(data).Do().BindJSON(&rest); err != nil {
log.Println("API 查询失败:", err)
}
log.Println("API 查询成功:", rest)
for _, img := range rest.Images {
log.Println("保存图片:", img)
// 将base64编码的图片保存到本地webp
if err := SaveBase64Image(img, "data/images/"+img+".webp"); err != nil {
log.Println(err)
}
}
}
}
}
// 将base64编码的图片保存到本地webp
func SaveBase64Image(base64Str string, filename string) error {
// 解码base64图片
data, err := base64.StdEncoding.DecodeString(base64Str)
if err != nil {
return err
}
// 将png图片解码为image.Image
img, err := png.Decode(bytes.NewReader(data))
if err != nil {
return err
}
// 创建webp编码器
webpWriter, err := os.Create(filename)
if err != nil {
return err
}
defer webpWriter.Close()
// 将image.Image编码为webp格式并保存到本地
if err := webp.Encode(webpWriter, img, &webp.Options{Lossless: true}); err != nil {
return err
}
return nil
}
func (model *Model) Train() (err error) {

View File

@@ -3,59 +3,44 @@ package models
import (
"sync"
"github.com/google/uuid"
"github.com/gorilla/websocket"
)
type WebSocketManager struct {
connections map[string]*websocket.Conn
listeners map[string]map[chan struct{}]struct{}
connections map[*websocket.Conn]string // 连接指针:任务ID
mutex sync.RWMutex
}
// 创建一个新的连接池
func NewWebSocketManager() *WebSocketManager {
return &WebSocketManager{
connections: make(map[string]*websocket.Conn),
connections: make(map[*websocket.Conn]string),
mutex: sync.RWMutex{},
}
}
func (mgr *WebSocketManager) AddConnection(conn *websocket.Conn) string {
// 向连接池加入一个新连接
func (mgr *WebSocketManager) AddConnection(conn *websocket.Conn, task string) {
mgr.mutex.Lock()
defer mgr.mutex.Unlock()
id := uuid.New().String() // 为每个连接生成一个唯一的 ID
mgr.connections[id] = conn
return id
mgr.connections[conn] = task
}
func (mgr *WebSocketManager) RemoveConnection(id string) {
// 从连接池中移除一个连接
func (mgr *WebSocketManager) RemoveConnection(conn *websocket.Conn) {
mgr.mutex.Lock()
defer mgr.mutex.Unlock()
delete(mgr.connections, id)
delete(mgr.connections, conn)
}
func (mgr *WebSocketManager) ListenForChanges(target string, callback func()) {
notifications := make(chan struct{})
// 任务状态变化时, 向监听此任务的所有连接发送消息
func (mgr *WebSocketManager) NotifyTaskChange(task string, data interface{}) {
mgr.mutex.Lock()
defer mgr.mutex.Unlock()
if _, ok := mgr.listeners[target]; !ok {
mgr.listeners[target] = make(map[chan struct{}]struct{})
}
mgr.listeners[target][notifications] = struct{}{}
go func() {
for {
callback()
for listener := range mgr.listeners[target] {
select {
case listener <- struct{}{}:
default:
delete(mgr.listeners[target], listener)
}
}
for conn, value := range mgr.connections {
if value == task {
conn.WriteJSON(data)
}
}()
}
}

View File

@@ -1,6 +1,7 @@
package models
import (
"database/sql/driver"
"encoding/json"
"fmt"
"main/configs"
@@ -8,23 +9,40 @@ import (
"time"
)
type Server struct {
ID string `json:"id" gorm:"primary_key"`
Name string `json:"name"`
Type string `json:"type"` // (訓練|推理)
IP string `json:"ip"`
Port int `json:"port"`
Status string `json:"status"` // (異常|初始化|閒置|就緒|工作中|關閉中)
UserName string `json:"username"`
Password string `json:"password"`
Models []map[string]interface{} `json:"models" gorm:"-"` // 數據庫不必保存
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
type ModelList []string
func (list *ModelList) Scan(value interface{}) error {
return json.Unmarshal(value.([]byte), list)
}
func (list ModelList) Value() (driver.Value, error) {
return json.Marshal(list)
}
type Server struct {
ID string `json:"id" gorm:"primary_key"`
Name string `json:"name"`
Type string `json:"type"` // (训练|推理)
IP string `json:"ip"`
Port int `json:"port"`
Status string `json:"status"` // (異常|初始化|閒置|就緒|工作中|關閉中)
UserName string `json:"username"`
Password string `json:"password"`
Models ModelList `json:"models"`
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
}
// 获取所有服务器
func GetServers() (servers []Server, err error) {
err = configs.ORMDB().Find(&servers).Error
return
}
// 檢查服務器是否正常
func (server *Server) CheckStatus() error {
switch server.Type {
case "訓練":
case "训练":
resp, err := http.Get(fmt.Sprintf("http://%s:%d/dreambooth/status", server.IP, server.Port))
if err != nil {
server.Status = "異常"