package routers import ( "crypto/md5" "encoding/json" "fmt" "image" _ "image/gif" _ "image/jpeg" _ "image/png" "regexp" "io/ioutil" "log" "main/configs" "main/models" "main/utils" "net/http" "os" "github.com/google/uuid" "github.com/gorilla/mux" "github.com/gorilla/websocket" ) var images_websocket_manager = models.NewWebSocketManager() func ImagesGet(w http.ResponseWriter, r *http.Request) { // websocket 推理图像 if r.Header.Get("Upgrade") == "websocket" { upgrader := websocket.Upgrader{} upgrader.CheckOrigin = func(r *http.Request) bool { return true } conn, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Println(err) return } defer conn.Close() // 读取任务信息 task := r.URL.Query().Get("task") if task == "" { w.WriteHeader(http.StatusBadRequest) w.Write([]byte("task 参数不能为空")) return } // 从数据库中读取任务信息 var image_list []models.Image if err := configs.ORMDB().Where("task = ?", task).Find(&image_list).Error; err != nil { log.Println(err) return } if len(image_list) == 0 { w.WriteHeader(http.StatusBadRequest) w.Write([]byte("任务不存在或已结束")) return } log.Println("任务编号:", task, "任务数量:", len(image_list)) // 加入连接池 images_websocket_manager.AddConnection(conn, task) defer images_websocket_manager.RemoveConnection(conn) for { _, msg, err := conn.ReadMessage() if err != nil { log.Println(err) return } log.Println(string(msg)) if string(msg) == "close" { break } } return } var listview models.ListView listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1) listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10) var image_list []models.Image db := configs.ORMDB() db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Find(&image_list) for _, image := range image_list { listview.List = append(listview.List, image) } db.Model(&models.Image{}).Count(&listview.Total) listview.Next = listview.Page*listview.PageSize < int(listview.Total) listview.WriteJSON(w) } func ImagesPost(w http.ResponseWriter, r *http.Request) { models.AccountRead(w, r, func(account *models.Account) { // 通过模型推理生成图像, 为图像标记任务批次 if match, _ := regexp.MatchString("application/json", r.Header.Get("Content-Type")); match { template := &struct { FromImage int `json:"from_image"` // 来源图片(图生图时使用) Prompt string `json:"prompt"` // 提示词 NegativePrompt string `json:"negative_prompt"` // 负面提示词 Steps int `json:"steps"` // 迭代步数 CfgScale int `json:"cfg_scale"` // 提示词引导系数 (CFG Scale) SamplerName string `json:"sampler_name"` // 采样器名称(Sampler Name) Seed int `json:"seed"` // 随机种子(单张图生成时使用) NIter int `json:"n_iter"` // 生成数量 ModelID int `json:"model_id"` // 模型ID }{} body, err := ioutil.ReadAll(r.Body) if err != nil { log.Println(err) return } defer r.Body.Close() if err = json.Unmarshal(body, &template); err != nil { log.Println(err) return } // 输入检查 if template.NIter <= 0 { template.NIter = 1 } if template.Steps <= 0 { template.Steps = 50 } if template.CfgScale <= 0 { template.CfgScale = 1 } if template.CfgScale > 20 { template.CfgScale = 20 } if template.ModelID <= 0 { w.WriteHeader(http.StatusBadRequest) w.Write([]byte("model_id 参数不能为空")) return } // 从数据库中读取模型信息 var model models.Model = models.Model{ID: template.ModelID} if err := configs.ORMDB().First(&model).Error; err != nil { w.WriteHeader(http.StatusBadRequest) w.Write([]byte("模型不存在")) return } // 直接创建一组图片 var image_list []models.Image var task string = uuid.New().String() for i := 0; i < template.NIter; i++ { var image models.Image image.UserID = account.ID image.Task = task image.Status = "queued" image.FromImage = template.FromImage image.Prompt = template.Prompt image.NegativePrompt = template.NegativePrompt image.Steps = template.Steps image.CfgScale = template.CfgScale image.SamplerName = template.SamplerName image.Seed = template.Seed image_list = append(image_list, image) } // 推理图像 go model.Inference(image_list, func(img models.Image) { log.Println("推理完成") images_websocket_manager.NotifyTaskChange(task, img) // 通知 websocket configs.ORMDB().Model(&img).Updates(img) // 更新到数据库 }) // 存储图片信息到数据库 if err := configs.ORMDB().Create(&image_list).Error; err != nil { log.Println(err) return } w.Header().Set("Content-Type", "application/json; charset=utf-8") json.NewEncoder(w).Encode(image_list) return } // 接收上傳的圖片文件, 僅限一張 file, file_header, err := r.FormFile("file") if err != nil { log.Println(err) return } defer file.Close() // 圖片寬高 imgData, format, err := image.Decode(file) if err != nil { log.Println(err) return } fmt.Println(format, imgData.Bounds().Dx(), imgData.Bounds().Dy()) // 將文件指針移回開頭 if _, err := file.Seek(0, 0); err != nil { log.Println(err) return } // 读取文件内容 content, err := ioutil.ReadAll(file) if err != nil { log.Println(err) return } // 整理文件信息 var img models.Image img.Name = file_header.Filename img.Size = int(file_header.Size) // 數據大小 img.Hash = fmt.Sprintf("%x", md5.Sum(content)) // 计算哈希 img.Type = file_header.Header.Get("Content-Type") // 文件類型 img.Path = fmt.Sprintf("data/images/%s.%s", img.Hash, format) // 存儲路徑 img.Width = imgData.Bounds().Dx() // 圖片寬度 img.Height = imgData.Bounds().Dy() // 圖片高度 img.Format = format // 圖片格式 img.UserID = account.ID // 用戶ID // 先檢查 data/images 目錄是否存在 if _, err := ioutil.ReadDir("data/images"); err != nil { if err := os.Mkdir("data/images", 0777); err != nil { log.Println(err) return } } // 將文件存儲到本地 data/images 目錄下 if err := ioutil.WriteFile(img.Path, content, 0666); err != nil { log.Println(err) return } // 存儲圖片信息到數據庫 if err := configs.ORMDB().Create(&img).Error; err != nil { log.Println(err) return } w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(img)) }) } func ImagesItemGet(w http.ResponseWriter, r *http.Request) { image := models.Image{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} if err := configs.ORMDB().First(&image).Error; err != nil { log.Println(err) return } w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(image)) } func ImagesItemPatch(w http.ResponseWriter, r *http.Request) { image := models.Image{} body, err := ioutil.ReadAll(r.Body) if err != nil { log.Println(err) return } defer r.Body.Close() if err = json.Unmarshal(body, &image); err != nil { log.Println(err) return } image.ID = utils.ParamInt(mux.Vars(r)["id"], 0) if err := configs.ORMDB().Model(&image).Updates(image).Error; err != nil { log.Println(err) return } w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(image)) } func ImagesItemDelete(w http.ResponseWriter, r *http.Request) { image := models.Image{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} if err := configs.ORMDB().Delete(&image).Error; err != nil { log.Println(err) return } w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(image)) }