Files
ai/routers/images.go
2023-06-21 18:37:27 +08:00

299 lines
8.1 KiB
Go

package routers
import (
"crypto/md5"
"encoding/json"
"fmt"
"image"
_ "image/gif"
_ "image/jpeg"
_ "image/png"
"regexp"
"io/ioutil"
"log"
"main/configs"
"main/models"
"main/utils"
"net/http"
"os"
"github.com/google/uuid"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
)
// images_websocket_manager holds the WebSocket connections opened by ImagesGet.
// Each connection is registered together with its task ID. Inference callbacks
// started by ImagesPost publish per-image updates for that task through it.
var (
	images_websocket_manager = models.NewWebSocketManager()
)
func ImagesGet(w http.ResponseWriter, r *http.Request) {
// websocket 推理图像
if r.Header.Get("Upgrade") == "websocket" {
upgrader := websocket.Upgrader{}
upgrader.CheckOrigin = func(r *http.Request) bool {
return true
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Println(err)
return
}
defer conn.Close()
// 读取任务信息
task := r.URL.Query().Get("task")
if task == "" {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("task 参数不能为空"))
return
}
// 从数据库中读取任务信息
var image_list []models.Image
if err := configs.ORMDB().Where("task = ?", task).Find(&image_list).Error; err != nil {
log.Println(err)
return
}
if len(image_list) == 0 {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("任务不存在或已结束"))
return
}
log.Println("任务编号:", task, "任务数量:", len(image_list))
// 加入连接池
images_websocket_manager.AddConnection(conn, task)
defer images_websocket_manager.RemoveConnection(conn)
for {
_, msg, err := conn.ReadMessage()
if err != nil {
log.Println(err)
return
}
log.Println(string(msg))
if string(msg) == "close" {
break
}
}
return
}
var listview models.ListView
listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1)
listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10)
var image_list []models.Image
db := configs.ORMDB()
db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Find(&image_list)
for _, image := range image_list {
listview.List = append(listview.List, image)
}
db.Model(&models.Image{}).Count(&listview.Total)
listview.Next = listview.Page*listview.PageSize < int(listview.Total)
listview.WriteJSON(w)
}
// imagesJSONContentType matches Content-Type headers that announce a JSON body.
// It is compiled once rather than on every request.
var imagesJSONContentType = regexp.MustCompile("application/json")

// ImagesPost creates images for the authenticated account.
//
// A request whose Content-Type contains "application/json" is a generation
// request. Queued image rows are stored under a new task ID, and model
// inference for them is started in the background. Any other request is
// treated as a multipart upload of a single image in the "file" form field.
func ImagesPost(w http.ResponseWriter, r *http.Request) {
	models.AccountRead(w, r, func(account *models.Account) {
		if imagesJSONContentType.MatchString(r.Header.Get("Content-Type")) {
			imagesPostGenerate(w, r, account)
			return
		}
		imagesPostUpload(w, r, account)
	})
}

// imagesPostFail logs err and replies with the standard status text for status.
func imagesPostFail(w http.ResponseWriter, status int, err error) {
	log.Println(err)
	http.Error(w, http.StatusText(status), status)
}

// imagesPostGenerate validates a generation request, stores one queued image
// per requested output, starts inference on them and responds with the stored
// rows as JSON.
func imagesPostGenerate(w http.ResponseWriter, r *http.Request, account *models.Account) {
	var template struct {
		FromImage         int     `json:"from_image"`          // source image (image-to-image only)
		Prompt            string  `json:"prompt"`              // prompt
		NegativePrompt    string  `json:"negative_prompt"`     // negative prompt
		NumInferenceSteps int     `json:"num_inference_steps"` // number of inference steps
		GuidanceScale     float32 `json:"guidance_scale"`      // guidance scale
		Scheduler         string  `json:"scheduler"`           // scheduler name
		Seed              string  `json:"seed"`                // random seed (single-image generation)
		Number            int     `json:"number"`              // number of images to generate
		ModelID           int     `json:"model_id"`            // model ID
	}
	body, err := ioutil.ReadAll(r.Body)
	if err != nil {
		imagesPostFail(w, http.StatusBadRequest, err)
		return
	}
	defer r.Body.Close()
	if err := json.Unmarshal(body, &template); err != nil {
		imagesPostFail(w, http.StatusBadRequest, err)
		return
	}

	// Apply defaults and clamp out-of-range values.
	// NOTE(review): Number has no upper bound, so a single request can queue an
	// arbitrary number of rows. Consider capping it.
	if template.Number <= 0 {
		template.Number = 1
	}
	if template.NumInferenceSteps <= 0 {
		template.NumInferenceSteps = 20
	}
	if template.GuidanceScale <= 0 {
		template.GuidanceScale = 1
	}
	if template.GuidanceScale > 20 {
		template.GuidanceScale = 20
	}
	if template.Scheduler == "" {
		template.Scheduler = "DDIM"
	}
	if template.ModelID <= 0 {
		http.Error(w, "model_id 参数不能为空", http.StatusBadRequest)
		return
	}

	// Look up the model that will run inference.
	model := models.Model{ID: template.ModelID}
	if err := configs.ORMDB().First(&model).Error; err != nil {
		http.Error(w, "模型不存在", http.StatusBadRequest)
		return
	}

	// Build one queued image per requested output, all tagged with one task ID.
	task := uuid.New().String()
	imageList := make([]models.Image, 0, template.Number)
	for i := 0; i < template.Number; i++ {
		var item models.Image
		item.UserID = account.ID
		item.Task = task
		item.Status = "queued"
		item.FromImage = template.FromImage
		item.Prompt = template.Prompt
		item.NegativePrompt = template.NegativePrompt
		item.NumInferenceSteps = template.NumInferenceSteps
		item.GuidanceScale = template.GuidanceScale
		item.Scheduler = template.Scheduler
		item.Seed = template.Seed
		imageList = append(imageList, item)
	}

	// Persist first, so inference only starts for stored rows and receives them
	// with their database IDs populated.
	if err := configs.ORMDB().Create(&imageList).Error; err != nil {
		imagesPostFail(w, http.StatusInternalServerError, err)
		return
	}

	// The inference goroutine outlives this handler. Give it its own copy so it
	// does not share a backing array with the JSON encoding below.
	queued := append([]models.Image(nil), imageList...)
	go model.Inference(queued, func(img models.Image) {
		log.Println("推理完成")
		images_websocket_manager.NotifyTaskChange(task, img)
	})

	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	if err := json.NewEncoder(w).Encode(imageList); err != nil {
		log.Println(err)
	}
}

// imagesPostUpload stores one uploaded image (form field "file") at
// data/images/<md5 hex>.<format> and records it for the account.
func imagesPostUpload(w http.ResponseWriter, r *http.Request, account *models.Account) {
	file, fileHeader, err := r.FormFile("file")
	if err != nil {
		imagesPostFail(w, http.StatusBadRequest, err)
		return
	}
	defer file.Close()

	// Decode the whole image. This rejects non-images and yields the format and
	// dimensions. Only gif, jpeg and png decoders are registered by this file.
	decoded, format, err := image.Decode(file)
	if err != nil {
		imagesPostFail(w, http.StatusBadRequest, err)
		return
	}
	bounds := decoded.Bounds()
	log.Println(format, bounds.Dx(), bounds.Dy())

	// Decode consumed the stream; rewind to read the raw bytes.
	if _, err := file.Seek(0, 0); err != nil {
		imagesPostFail(w, http.StatusInternalServerError, err)
		return
	}
	content, err := ioutil.ReadAll(file)
	if err != nil {
		imagesPostFail(w, http.StatusInternalServerError, err)
		return
	}

	var img models.Image
	img.Name = fileHeader.Filename
	img.Size = int(fileHeader.Size)                                // size in bytes
	img.Hash = fmt.Sprintf("%x", md5.Sum(content))                 // content hash
	img.Type = fileHeader.Header.Get("Content-Type")               // client-declared MIME type
	img.Path = fmt.Sprintf("data/images/%s.%s", img.Hash, format) // storage path
	img.Width = bounds.Dx()
	img.Height = bounds.Dy()
	img.Format = format
	img.UserID = account.ID

	// MkdirAll also creates missing parent directories and succeeds when the
	// directory already exists.
	if err := os.MkdirAll("data/images", 0755); err != nil {
		imagesPostFail(w, http.StatusInternalServerError, err)
		return
	}
	if err := ioutil.WriteFile(img.Path, content, 0644); err != nil {
		imagesPostFail(w, http.StatusInternalServerError, err)
		return
	}
	if err := configs.ORMDB().Create(&img).Error; err != nil {
		imagesPostFail(w, http.StatusInternalServerError, err)
		return
	}
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	w.Write(utils.ToJSON(img))
}
func ImagesItemGet(w http.ResponseWriter, r *http.Request) {
image := models.Image{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
if err := configs.ORMDB().First(&image).Error; err != nil {
log.Println(err)
return
}
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(image))
}
func ImagesItemPatch(w http.ResponseWriter, r *http.Request) {
image := models.Image{}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
log.Println(err)
return
}
defer r.Body.Close()
if err = json.Unmarshal(body, &image); err != nil {
log.Println(err)
return
}
image.ID = utils.ParamInt(mux.Vars(r)["id"], 0)
if err := configs.ORMDB().Model(&image).Updates(image).Error; err != nil {
log.Println(err)
return
}
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(image))
}
func ImagesItemDelete(w http.ResponseWriter, r *http.Request) {
image := models.Image{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
if err := configs.ORMDB().Delete(&image).Error; err != nil {
log.Println(err)
return
}
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(image))
}