369 lines
10 KiB
Go
369 lines
10 KiB
Go
package routers
|
|
|
|
import (
|
|
"crypto/md5"
|
|
"encoding/json"
|
|
"fmt"
|
|
"image"
|
|
_ "image/gif"
|
|
_ "image/jpeg"
|
|
_ "image/png"
|
|
"regexp"
|
|
"strconv"
|
|
|
|
"io/ioutil"
|
|
"log"
|
|
"main/configs"
|
|
"main/models"
|
|
"main/utils"
|
|
"net/http"
|
|
"os"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/gorilla/mux"
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
var (
	// images_websocket_manager holds the websocket connections opened by
	// ImagesGet. Each connection is registered together with its task id,
	// and inference callbacks in ImagesPost use it to notify the watchers
	// of that task when an image changes.
	images_websocket_manager = models.NewWebSocketManager()
)
|
|
|
|
func ImagesGet(w http.ResponseWriter, r *http.Request) {
|
|
|
|
// websocket 推理图像
|
|
if r.Header.Get("Upgrade") == "websocket" {
|
|
upgrader := websocket.Upgrader{}
|
|
upgrader.CheckOrigin = func(r *http.Request) bool {
|
|
return true
|
|
}
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
// 读取任务信息
|
|
task := r.URL.Query().Get("task")
|
|
if task == "" {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
w.Write([]byte("task 参数不能为空"))
|
|
return
|
|
}
|
|
// 从数据库中读取任务信息
|
|
var image_list []models.Image
|
|
if err := configs.ORMDB().Where("task = ?", task).Find(&image_list).Error; err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
if len(image_list) == 0 {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
w.Write([]byte("任务不存在或已结束"))
|
|
return
|
|
}
|
|
|
|
log.Println("任务编号:", task, "任务数量:", len(image_list))
|
|
|
|
// 加入连接池
|
|
images_websocket_manager.AddConnection(conn, task)
|
|
defer images_websocket_manager.RemoveConnection(conn)
|
|
|
|
for {
|
|
_, msg, err := conn.ReadMessage()
|
|
if err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
log.Println(string(msg))
|
|
if string(msg) == "close" {
|
|
break
|
|
}
|
|
}
|
|
return
|
|
|
|
}
|
|
|
|
var listview models.ListView
|
|
listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1)
|
|
listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10)
|
|
|
|
var image_list []models.Image
|
|
db := configs.ORMDB()
|
|
if r.URL.Query().Get("task") != "" {
|
|
db = db.Where("task = ?", r.URL.Query().Get("task"))
|
|
}
|
|
if r.URL.Query().Get("user_id") != "" {
|
|
db = db.Where("user_id = ?", r.URL.Query().Get("user_id"))
|
|
}
|
|
if r.URL.Query().Get("status") != "" {
|
|
db = db.Where("status = ?", r.URL.Query().Get("status"))
|
|
}
|
|
if r.URL.Query().Get("from_image") != "" {
|
|
db = db.Where("from_image = ?", r.URL.Query().Get("from_image"))
|
|
}
|
|
if r.URL.Query().Get("prompt") != "" {
|
|
db = db.Where("prompt LIKE ?", "%"+r.URL.Query().Get("prompt")+"%")
|
|
}
|
|
if r.URL.Query().Get("negative_prompt") != "" {
|
|
db = db.Where("negative_prompt LIKE ?", "%"+r.URL.Query().Get("negative_prompt")+"%")
|
|
}
|
|
|
|
// 获取指定用户喜欢的图片
|
|
if r.URL.Query().Get("like") != "" {
|
|
list, err := models.LikeImage.GetA(r.URL.Query().Get("like"))
|
|
if err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
db = db.Where("id IN (?)", list)
|
|
}
|
|
|
|
db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Find(&image_list)
|
|
for _, image := range image_list {
|
|
listview.List = append(listview.List, image)
|
|
}
|
|
|
|
db.Model(&models.Image{}).Count(&listview.Total)
|
|
|
|
listview.Next = listview.Page*listview.PageSize < int(listview.Total)
|
|
listview.WriteJSON(w)
|
|
}
|
|
|
|
func ImagesPost(w http.ResponseWriter, r *http.Request) {
|
|
models.AccountRead(w, r, func(account *models.Account) {
|
|
|
|
// 通过模型推理生成图像, 为图像标记任务批次
|
|
if match, _ := regexp.MatchString("application/json", r.Header.Get("Content-Type")); match {
|
|
template := &struct {
|
|
FromImage int `json:"from_image"` // 来源图片(图生图时使用)
|
|
Prompt string `json:"prompt"` // 提示词
|
|
NegativePrompt string `json:"negative_prompt"` // 负面提示词
|
|
Steps int `json:"steps"` // 迭代步数
|
|
CfgScale int `json:"cfg_scale"` // 提示词引导系数 (CFG Scale)
|
|
SamplerName string `json:"sampler_name"` // 采样器名称(Sampler Name)
|
|
Seed int `json:"seed"` // 随机种子(单张图生成时使用)
|
|
NIter int `json:"n_iter"` // 生成数量
|
|
ModelID int `json:"model_id"` // 模型ID
|
|
Width int `json:"width"` // 图片宽度
|
|
Height int `json:"height"` // 图片高度
|
|
}{}
|
|
body, err := ioutil.ReadAll(r.Body)
|
|
if err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
defer r.Body.Close()
|
|
if err = json.Unmarshal(body, &template); err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
|
|
// 输入检查
|
|
if template.NIter <= 0 {
|
|
template.NIter = 1
|
|
}
|
|
if template.Steps <= 0 {
|
|
template.Steps = 50
|
|
}
|
|
if template.CfgScale <= 0 {
|
|
template.CfgScale = 1
|
|
}
|
|
if template.CfgScale > 20 {
|
|
template.CfgScale = 20
|
|
}
|
|
if template.ModelID <= 0 {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
w.Write([]byte("model_id 参数不能为空"))
|
|
return
|
|
}
|
|
|
|
// 从数据库中读取模型信息
|
|
var model models.Model = models.Model{ID: template.ModelID}
|
|
if err := configs.ORMDB().First(&model).Error; err != nil {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
w.Write([]byte("模型不存在"))
|
|
return
|
|
}
|
|
|
|
// 直接创建一组图片
|
|
var image_list []models.Image
|
|
var task string = uuid.New().String()
|
|
for i := 0; i < template.NIter; i++ {
|
|
var image models.Image
|
|
image.UserID = account.ID
|
|
image.Task = task
|
|
image.Status = "queued"
|
|
image.FromImage = template.FromImage
|
|
image.Prompt = template.Prompt
|
|
image.NegativePrompt = template.NegativePrompt
|
|
image.Steps = template.Steps
|
|
image.CfgScale = template.CfgScale
|
|
image.SamplerName = template.SamplerName
|
|
image.Seed = template.Seed
|
|
image.ModelID = template.ModelID
|
|
image.Width = template.Width
|
|
image.Height = template.Height
|
|
image_list = append(image_list, image)
|
|
}
|
|
|
|
// 推理图像
|
|
go model.Inference(image_list, func(img models.Image) {
|
|
log.Println("推理完成")
|
|
images_websocket_manager.NotifyTaskChange(task, img) // 通知 websocket
|
|
configs.ORMDB().Model(&img).Updates(img) // 更新到数据库
|
|
})
|
|
|
|
// 存储图片信息到数据库
|
|
if err := configs.ORMDB().Create(&image_list).Error; err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
|
json.NewEncoder(w).Encode(image_list)
|
|
return
|
|
}
|
|
|
|
// 接收上傳的圖片文件, 僅限一張
|
|
file, file_header, err := r.FormFile("file")
|
|
if err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
defer file.Close()
|
|
|
|
// 圖片寬高
|
|
imgData, format, err := image.Decode(file)
|
|
if err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
|
|
fmt.Println(format, imgData.Bounds().Dx(), imgData.Bounds().Dy())
|
|
|
|
// 將文件指針移回開頭
|
|
if _, err := file.Seek(0, 0); err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
|
|
// 读取文件内容
|
|
content, err := ioutil.ReadAll(file)
|
|
if err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
|
|
// 整理文件信息
|
|
var img models.Image
|
|
img.Name = file_header.Filename
|
|
img.Size = int(file_header.Size) // 數據大小
|
|
img.Hash = fmt.Sprintf("%x", md5.Sum(content)) // 计算哈希
|
|
img.Type = file_header.Header.Get("Content-Type") // 文件類型
|
|
img.Path = fmt.Sprintf("data/images/%s.%s", img.Hash, format) // 存儲路徑
|
|
img.Width = imgData.Bounds().Dx() // 圖片寬度
|
|
img.Height = imgData.Bounds().Dy() // 圖片高度
|
|
img.Format = format // 圖片格式
|
|
img.UserID = account.ID // 用戶ID
|
|
|
|
// 先檢查 data/images 目錄是否存在
|
|
if _, err := ioutil.ReadDir("data/images"); err != nil {
|
|
if err := os.Mkdir("data/images", 0777); err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
}
|
|
|
|
// 將文件存儲到本地 data/images 目錄下
|
|
if err := ioutil.WriteFile(img.Path, content, 0666); err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
|
|
// 存儲圖片信息到數據庫
|
|
if err := configs.ORMDB().Create(&img).Error; err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
|
w.Write(utils.ToJSON(img))
|
|
})
|
|
}
|
|
|
|
func ImagesItemGet(w http.ResponseWriter, r *http.Request) {
|
|
image := models.Image{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
|
|
if err := configs.ORMDB().First(&image).Error; err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
|
w.Write(utils.ToJSON(image))
|
|
}
|
|
|
|
func ImagesItemPatch(w http.ResponseWriter, r *http.Request) {
|
|
image := models.Image{}
|
|
body, err := ioutil.ReadAll(r.Body)
|
|
if err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
defer r.Body.Close()
|
|
if err = json.Unmarshal(body, &image); err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
image.ID = utils.ParamInt(mux.Vars(r)["id"], 0)
|
|
if err := configs.ORMDB().Model(&image).Updates(image).Error; err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
|
w.Write(utils.ToJSON(image))
|
|
}
|
|
|
|
// 删除一条图片
|
|
func ImagesItemDelete(w http.ResponseWriter, r *http.Request) {
|
|
image := models.Image{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
|
|
if err := configs.ORMDB().First(&image).Error; err != nil {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
w.Write([]byte("图片不存在"))
|
|
return
|
|
}
|
|
if err := configs.ORMDB().Delete(&image).Error; err != nil {
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
w.Write([]byte("删除失败"))
|
|
return
|
|
}
|
|
// 删除本地图像文件 image.Path
|
|
os.Remove(image.Path)
|
|
|
|
// 删除所有用户喜欢此图片的记录(双向解绑, A是user, B是image)
|
|
models.LikeImage.RemoveB(strconv.Itoa(image.ID))
|
|
|
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
|
w.Write(utils.ToJSON(image))
|
|
}
|
|
|
|
// 添加一条喜欢
|
|
func ImagesItemLike(w http.ResponseWriter, r *http.Request) {
|
|
models.AccountRead(w, r, func(account *models.Account) {
|
|
// 先检查图片是否存在
|
|
image := models.Image{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
|
|
if err := configs.ORMDB().First(&image).Error; err != nil {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
w.Write([]byte("图片不存在"))
|
|
return
|
|
}
|
|
// 添加喜欢
|
|
models.LikeImage.Add(strconv.Itoa(account.ID), strconv.Itoa(image.ID))
|
|
w.Write([]byte("ok"))
|
|
})
|
|
}
|
|
|
|
// 移除一条喜欢
|
|
func ImagesItemUnlike(w http.ResponseWriter, r *http.Request) {
|
|
models.AccountRead(w, r, func(account *models.Account) {
|
|
models.LikeImage.Remove(strconv.Itoa(account.ID), mux.Vars(r)["id"])
|
|
w.Write([]byte("ok"))
|
|
})
|
|
}
|