生成图像(默认参数)
This commit is contained in:
		
							
								
								
									
										184
									
								
								models/Model.go
									
									
									
									
									
								
							
							
						
						
									
										184
									
								
								models/Model.go
									
									
									
									
									
								
							@@ -1,15 +1,23 @@
 | 
			
		||||
package models
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"crypto/md5"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
	"log"
 | 
			
		||||
	"main/configs"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"image/png"
 | 
			
		||||
 | 
			
		||||
	"github.com/chai2010/webp"
 | 
			
		||||
	"github.com/zhshch2002/goreq"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Model struct {
 | 
			
		||||
@@ -35,8 +43,184 @@ type Model struct {
 | 
			
		||||
	UpdatedAt    time.Time `json:"updated_at" gorm:"autoUpdateTime"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 创建一个带缓冲的通道,缓冲区大小为 10
 | 
			
		||||
// var ch = make(chan int, 10)
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	configs.ORMDB().AutoMigrate(&Model{})
 | 
			
		||||
 | 
			
		||||
	// 处理推理任务
 | 
			
		||||
	//go func() {
 | 
			
		||||
	//	for {
 | 
			
		||||
	//		// 从通道中取出一个数据
 | 
			
		||||
	//		model := <-ch
 | 
			
		||||
	//		// 模型状态变化时, 向监听此模型的所有连接发送消息
 | 
			
		||||
	//	}
 | 
			
		||||
	//}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (model *Model) Inference(image_list []Image, callback func()) {
 | 
			
		||||
	log.Println(image_list)
 | 
			
		||||
	callback()
 | 
			
		||||
 | 
			
		||||
	// 模型未部署到推理機
 | 
			
		||||
	if model.ServerID == "" {
 | 
			
		||||
		log.Println("模型未部署到推理機, 开始部署模型")
 | 
			
		||||
 | 
			
		||||
		var server Server
 | 
			
		||||
		if err := configs.ORMDB().Where("models LIKE ?", "%"+model.Name+"%").Take(&server).Error; err != nil {
 | 
			
		||||
			log.Println(err)
 | 
			
		||||
			// 如果没有则寻找空闲服务器
 | 
			
		||||
			// 如果没有空闲则创建新服务器
 | 
			
		||||
			// 取一台空闲的推理机上传并切换到此模型
 | 
			
		||||
			// 新建一台推理机上传并切换到此模型
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 执行生成任务
 | 
			
		||||
		if model.Image == "" {
 | 
			
		||||
			var data = struct {
 | 
			
		||||
				EnableHr                          bool              `json:"enable_hr"`
 | 
			
		||||
				DenoisingStrength                 int               `json:"denoising_strength"`
 | 
			
		||||
				FirstphaseWidth                   int               `json:"firstphase_width"`
 | 
			
		||||
				FirstphaseHeight                  int               `json:"firstphase_height"`
 | 
			
		||||
				HrScale                           int               `json:"hr_scale"`
 | 
			
		||||
				HrUpscaler                        string            `json:"hr_upscaler"`
 | 
			
		||||
				HrSecondPassSteps                 int               `json:"hr_second_pass_steps"`
 | 
			
		||||
				HrResizeX                         int               `json:"hr_resize_x"`
 | 
			
		||||
				HrResizeY                         int               `json:"hr_resize_y"`
 | 
			
		||||
				HrSamplerName                     string            `json:"hr_sampler_name"`
 | 
			
		||||
				HrPrompt                          string            `json:"hr_prompt"`
 | 
			
		||||
				HrNegativePrompt                  string            `json:"hr_negative_prompt"`
 | 
			
		||||
				Prompt                            string            `json:"prompt"`
 | 
			
		||||
				Styles                            []string          `json:"styles"`
 | 
			
		||||
				Seed                              int               `json:"seed"`
 | 
			
		||||
				Subseed                           int               `json:"subseed"`
 | 
			
		||||
				SubseedStrength                   int               `json:"subseed_strength"`
 | 
			
		||||
				SeedResizeFromH                   int               `json:"seed_resize_from_h"`
 | 
			
		||||
				SeedResizeFromW                   int               `json:"seed_resize_from_w"`
 | 
			
		||||
				SamplerName                       string            `json:"sampler_name"`
 | 
			
		||||
				BatchSize                         int               `json:"batch_size"`
 | 
			
		||||
				NIter                             int               `json:"n_iter"`
 | 
			
		||||
				Steps                             int               `json:"steps"`
 | 
			
		||||
				CfgScale                          int               `json:"cfg_scale"`
 | 
			
		||||
				Width                             int               `json:"width"`
 | 
			
		||||
				Height                            int               `json:"height"`
 | 
			
		||||
				RestoreFaces                      bool              `json:"restore_faces"`
 | 
			
		||||
				Tiling                            bool              `json:"tiling"`
 | 
			
		||||
				DoNotSaveSamples                  bool              `json:"do_not_save_samples"`
 | 
			
		||||
				DoNotSaveGrid                     bool              `json:"do_not_save_grid"`
 | 
			
		||||
				NegativePrompt                    string            `json:"negative_prompt"`
 | 
			
		||||
				Eta                               int               `json:"eta"`
 | 
			
		||||
				SMinUncond                        int               `json:"s_min_uncond"`
 | 
			
		||||
				SChurn                            int               `json:"s_churn"`
 | 
			
		||||
				STmax                             int               `json:"s_tmax"`
 | 
			
		||||
				STmin                             int               `json:"s_tmin"`
 | 
			
		||||
				SNoise                            int               `json:"s_noise"`
 | 
			
		||||
				OverrideSettings                  map[string]string `json:"override_settings"`
 | 
			
		||||
				OverrideSettingsRestoreAfterwards bool              `json:"override_settings_restore_afterwards"`
 | 
			
		||||
				ScriptArgs                        []interface{}     `json:"script_args"`
 | 
			
		||||
				SamplerIndex                      string            `json:"sampler_index"`
 | 
			
		||||
				ScriptName                        string            `json:"script_name"`
 | 
			
		||||
				SendImages                        bool              `json:"send_images"`
 | 
			
		||||
				SaveImages                        bool              `json:"save_images"`
 | 
			
		||||
				AlwaysonScripts                   map[string]string `json:"alwayson_scripts"`
 | 
			
		||||
			}{
 | 
			
		||||
				EnableHr:                          false,
 | 
			
		||||
				DenoisingStrength:                 0,
 | 
			
		||||
				FirstphaseWidth:                   0,
 | 
			
		||||
				FirstphaseHeight:                  0,
 | 
			
		||||
				HrScale:                           2,
 | 
			
		||||
				HrUpscaler:                        "nearest",
 | 
			
		||||
				HrSecondPassSteps:                 0,
 | 
			
		||||
				HrResizeX:                         0,
 | 
			
		||||
				HrResizeY:                         0,
 | 
			
		||||
				HrSamplerName:                     "",
 | 
			
		||||
				HrPrompt:                          "",
 | 
			
		||||
				HrNegativePrompt:                  "",
 | 
			
		||||
				Prompt:                            "miao~",
 | 
			
		||||
				Styles:                            []string{},
 | 
			
		||||
				Seed:                              -1,
 | 
			
		||||
				Subseed:                           -1,
 | 
			
		||||
				SubseedStrength:                   0,
 | 
			
		||||
				SeedResizeFromH:                   -1,
 | 
			
		||||
				SeedResizeFromW:                   -1,
 | 
			
		||||
				SamplerName:                       "beamsearch",
 | 
			
		||||
				BatchSize:                         1,
 | 
			
		||||
				NIter:                             1,
 | 
			
		||||
				Steps:                             50,
 | 
			
		||||
				CfgScale:                          7,
 | 
			
		||||
				Width:                             512,
 | 
			
		||||
				Height:                            512,
 | 
			
		||||
				RestoreFaces:                      false,
 | 
			
		||||
				Tiling:                            false,
 | 
			
		||||
				DoNotSaveSamples:                  false,
 | 
			
		||||
				DoNotSaveGrid:                     false,
 | 
			
		||||
				NegativePrompt:                    "",
 | 
			
		||||
				Eta:                               0,
 | 
			
		||||
				SMinUncond:                        0,
 | 
			
		||||
				SChurn:                            0,
 | 
			
		||||
				STmax:                             0,
 | 
			
		||||
				STmin:                             0,
 | 
			
		||||
				SNoise:                            1,
 | 
			
		||||
				OverrideSettings:                  map[string]string{},
 | 
			
		||||
				OverrideSettingsRestoreAfterwards: false,
 | 
			
		||||
				ScriptArgs:                        []interface{}{},
 | 
			
		||||
				SamplerIndex:                      "Euler",
 | 
			
		||||
				ScriptName:                        "generate",
 | 
			
		||||
				SendImages:                        true,
 | 
			
		||||
				SaveImages:                        false,
 | 
			
		||||
				AlwaysonScripts:                   map[string]string{},
 | 
			
		||||
			}
 | 
			
		||||
			// 接收到的图片列表
 | 
			
		||||
			var rest = struct {
 | 
			
		||||
				Images []string `json:"images"`
 | 
			
		||||
			}{
 | 
			
		||||
				Images: []string{},
 | 
			
		||||
			}
 | 
			
		||||
			var url = fmt.Sprintf("http://%s:%d/sdapi/v1/txt2img", server.IP, server.Port)
 | 
			
		||||
			if err := goreq.Post(url).SetJsonBody(data).Do().BindJSON(&rest); err != nil {
 | 
			
		||||
				log.Println("API 查询失败:", err)
 | 
			
		||||
			}
 | 
			
		||||
			log.Println("API 查询成功:", rest)
 | 
			
		||||
			for _, img := range rest.Images {
 | 
			
		||||
				log.Println("保存图片:", img)
 | 
			
		||||
				// 将base64编码的图片保存到本地webp
 | 
			
		||||
				if err := SaveBase64Image(img, "data/images/"+img+".webp"); err != nil {
 | 
			
		||||
					log.Println(err)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 将base64编码的图片保存到本地webp
 | 
			
		||||
func SaveBase64Image(base64Str string, filename string) error {
 | 
			
		||||
	// 解码base64图片
 | 
			
		||||
	data, err := base64.StdEncoding.DecodeString(base64Str)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 将png图片解码为image.Image
 | 
			
		||||
	img, err := png.Decode(bytes.NewReader(data))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 创建webp编码器
 | 
			
		||||
	webpWriter, err := os.Create(filename)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	defer webpWriter.Close()
 | 
			
		||||
 | 
			
		||||
	// 将image.Image编码为webp格式并保存到本地
 | 
			
		||||
	if err := webp.Encode(webpWriter, img, &webp.Options{Lossless: true}); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (model *Model) Train() (err error) {
 | 
			
		||||
 
 | 
			
		||||
@@ -3,59 +3,44 @@ package models
 | 
			
		||||
import (
 | 
			
		||||
	"sync"
 | 
			
		||||
 | 
			
		||||
	"github.com/google/uuid"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type WebSocketManager struct {
 | 
			
		||||
	connections map[string]*websocket.Conn
 | 
			
		||||
	listeners   map[string]map[chan struct{}]struct{}
 | 
			
		||||
	connections map[*websocket.Conn]string // 连接指针:任务ID
 | 
			
		||||
	mutex       sync.RWMutex
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 创建一个新的连接池
 | 
			
		||||
func NewWebSocketManager() *WebSocketManager {
 | 
			
		||||
	return &WebSocketManager{
 | 
			
		||||
		connections: make(map[string]*websocket.Conn),
 | 
			
		||||
		connections: make(map[*websocket.Conn]string),
 | 
			
		||||
		mutex:       sync.RWMutex{},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (mgr *WebSocketManager) AddConnection(conn *websocket.Conn) string {
 | 
			
		||||
// 向连接池加入一个新连接
 | 
			
		||||
func (mgr *WebSocketManager) AddConnection(conn *websocket.Conn, task string) {
 | 
			
		||||
	mgr.mutex.Lock()
 | 
			
		||||
	defer mgr.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	id := uuid.New().String() // 为每个连接生成一个唯一的 ID
 | 
			
		||||
	mgr.connections[id] = conn
 | 
			
		||||
 | 
			
		||||
	return id
 | 
			
		||||
	mgr.connections[conn] = task
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (mgr *WebSocketManager) RemoveConnection(id string) {
 | 
			
		||||
// 从连接池中移除一个连接
 | 
			
		||||
func (mgr *WebSocketManager) RemoveConnection(conn *websocket.Conn) {
 | 
			
		||||
	mgr.mutex.Lock()
 | 
			
		||||
	defer mgr.mutex.Unlock()
 | 
			
		||||
	delete(mgr.connections, id)
 | 
			
		||||
	delete(mgr.connections, conn)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (mgr *WebSocketManager) ListenForChanges(target string, callback func()) {
 | 
			
		||||
	notifications := make(chan struct{})
 | 
			
		||||
// 任务状态变化时, 向监听此任务的所有连接发送消息
 | 
			
		||||
func (mgr *WebSocketManager) NotifyTaskChange(task string, data interface{}) {
 | 
			
		||||
	mgr.mutex.Lock()
 | 
			
		||||
	defer mgr.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	if _, ok := mgr.listeners[target]; !ok {
 | 
			
		||||
		mgr.listeners[target] = make(map[chan struct{}]struct{})
 | 
			
		||||
	}
 | 
			
		||||
	mgr.listeners[target][notifications] = struct{}{}
 | 
			
		||||
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			callback()
 | 
			
		||||
			for listener := range mgr.listeners[target] {
 | 
			
		||||
				select {
 | 
			
		||||
				case listener <- struct{}{}:
 | 
			
		||||
				default:
 | 
			
		||||
					delete(mgr.listeners[target], listener)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
	for conn, value := range mgr.connections {
 | 
			
		||||
		if value == task {
 | 
			
		||||
			conn.WriteJSON(data)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,7 @@
 | 
			
		||||
package models
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"main/configs"
 | 
			
		||||
@@ -8,23 +9,40 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Server struct {
 | 
			
		||||
	ID        string                   `json:"id" gorm:"primary_key"`
 | 
			
		||||
	Name      string                   `json:"name"`
 | 
			
		||||
	Type      string                   `json:"type"` // (訓練|推理)
 | 
			
		||||
	IP        string                   `json:"ip"`
 | 
			
		||||
	Port      int                      `json:"port"`
 | 
			
		||||
	Status    string                   `json:"status"` // (異常|初始化|閒置|就緒|工作中|關閉中)
 | 
			
		||||
	UserName  string                   `json:"username"`
 | 
			
		||||
	Password  string                   `json:"password"`
 | 
			
		||||
	Models    []map[string]interface{} `json:"models" gorm:"-"` // 數據庫不必保存
 | 
			
		||||
	CreatedAt time.Time                `json:"created_at" gorm:"autoCreateTime"`
 | 
			
		||||
	UpdatedAt time.Time                `json:"updated_at" gorm:"autoUpdateTime"`
 | 
			
		||||
type ModelList []string
 | 
			
		||||
 | 
			
		||||
func (list *ModelList) Scan(value interface{}) error {
 | 
			
		||||
	return json.Unmarshal(value.([]byte), list)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (list ModelList) Value() (driver.Value, error) {
 | 
			
		||||
	return json.Marshal(list)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Server struct {
 | 
			
		||||
	ID        string    `json:"id" gorm:"primary_key"`
 | 
			
		||||
	Name      string    `json:"name"`
 | 
			
		||||
	Type      string    `json:"type"` // (训练|推理)
 | 
			
		||||
	IP        string    `json:"ip"`
 | 
			
		||||
	Port      int       `json:"port"`
 | 
			
		||||
	Status    string    `json:"status"` // (異常|初始化|閒置|就緒|工作中|關閉中)
 | 
			
		||||
	UserName  string    `json:"username"`
 | 
			
		||||
	Password  string    `json:"password"`
 | 
			
		||||
	Models    ModelList `json:"models"`
 | 
			
		||||
	CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
 | 
			
		||||
	UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取所有服务器
 | 
			
		||||
func GetServers() (servers []Server, err error) {
 | 
			
		||||
	err = configs.ORMDB().Find(&servers).Error
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 檢查服務器是否正常
 | 
			
		||||
func (server *Server) CheckStatus() error {
 | 
			
		||||
	switch server.Type {
 | 
			
		||||
	case "訓練":
 | 
			
		||||
	case "训练":
 | 
			
		||||
		resp, err := http.Get(fmt.Sprintf("http://%s:%d/dreambooth/status", server.IP, server.Port))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			server.Status = "異常"
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user