From b656065b43b0ec49adae0da62e8e8d98e730e38d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A1=9C=E8=8F=AF?= Date: Tue, 20 Jun 2023 18:09:22 +0800 Subject: [PATCH] =?UTF-8?q?=E9=80=9A=E8=BF=87list=E5=BB=BA=E7=AB=8Bwebsock?= =?UTF-8?q?et?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- models/Image.go | 42 +++++++++++++++++++++--------------------- routers/images.go | 43 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 63 insertions(+), 22 deletions(-) diff --git a/models/Image.go b/models/Image.go index 0efadd8..355a40b 100644 --- a/models/Image.go +++ b/models/Image.go @@ -16,27 +16,27 @@ import ( ) type Image struct { - ID int `json:"id" gorm:"primary_key"` - Name string `json:"name"` - Hash string `json:"hash"` - Path string `json:"path"` - Type string `json:"type"` - Size int `json:"size"` - Width int `json:"width"` - Height int `json:"height"` - Prompt string `json:"prompt"` - Format string `json:"format"` - NegativePrompt string `json:"negative_prompt"` - NumInferenceSteps int `json:"num_inference_steps"` // Number of inference steps (minimum: 1; maximum: 500) - GuidanceScale float32 `json:"guidance_scale"` // Scale for classifier-free guidance (minimum: 1; maximum: 20) - Scheduler string `json:"scheduler"` // (DDIM|K_EULER|DPMSolverMultistep|K_EULER_ANCESTRAL|PNDM|KLMS) - Seed int `json:"seed"` // 随机种子(minimum: 0; maximum: 2147483647) - FromImage int `json:"from_image"` // 来源图片(如果是从图片生成的, 则记录来源图片的ID) - Task string `json:"task"` // 任务编号(uuid) - Status string `json:"status"` // 任务状态(queued|running|finished|failed) - Progress int `json:"progress"` // 任务进度(0-100) - Public bool `json:"public"` // 是否公开 - UserID int `json:"user_id"` // 用户ID + ID int `json:"id" gorm:"primary_key"` // ID + Name string `json:"name"` // 名称 + Hash string `json:"hash"` // 哈希值 + Path string `json:"path"` // 路径 + Type string `json:"type"` // 类型 + Size int `json:"size"` // 大小 + Width int `json:"width"` // 宽度 + Height int `json:"height"` // 高度 + Format string `json:"format"` // 格式 + Prompt string `json:"prompt"` // 提示词 + NegativePrompt string `json:"negative_prompt"` // 负向提示 + NumInferenceSteps int `json:"num_inference_steps"` // 推理步数(minimum: 1; maximum: 500) + GuidanceScale float32 `json:"guidance_scale"` // 引导比例(minimum: 1; maximum: 20) + Scheduler string `json:"scheduler"` // 调度器(DDIM|K_EULER|DPMSolverMultistep|K_EULER_ANCESTRAL|PNDM|KLMS) + Seed int `json:"seed"` // 随机种子(minimum: 0; maximum: 2147483647) + FromImage int `json:"from_image"` // 来源图片(如果是从图片生成的, 则记录来源图片的ID) + Task string `json:"task"` // 任务编号(uuid) + Status string `json:"status"` // 任务状态(queued|running|finished|failed) + Progress int `json:"progress"` // 任务进度(0-100) + Public bool `json:"public"` // 是否公开 + UserID int `json:"user_id"` // 用户ID CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` } diff --git a/routers/images.go b/routers/images.go index 111b6f4..a9d6c1c 100644 --- a/routers/images.go +++ b/routers/images.go @@ -19,9 +19,48 @@ import ( "github.com/google/uuid" "github.com/gorilla/mux" + "github.com/gorilla/websocket" ) +var images_websocket_manager = models.NewWebSocketManager() + func ImagesGet(w http.ResponseWriter, r *http.Request) { + + // websocket 推理图像 + if r.Header.Get("Upgrade") == "websocket" { + upgrader := websocket.Upgrader{} + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Println(err) + return + } + defer conn.Close() + + task := r.URL.Query().Get("task") + if task == "" { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("task 参数不能为空")) + return + } + + wsid := images_websocket_manager.AddConnection(conn) + defer images_websocket_manager.RemoveConnection(wsid) + + for { + _, msg, err := conn.ReadMessage() + if err != nil { + log.Println(err) + return + } + log.Println(string(msg)) + if string(msg) == "close" { + break + } + } + return + + } + var listview models.ListView listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1) listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10) @@ -67,8 +106,10 @@ func ImagesPost(w http.ResponseWriter, r *http.Request) { return } - // 直接创建一组图片 + // TODO: 创建任务获得任务编号, 多张图时期望可以流式推理 task := uuid.New().String() + + // 直接创建一组图片 var image_list []models.Image for i := 0; i < template.Number; i++ { var image models.Image