diff --git a/models/Image.go b/models/Image.go index cfb4332..0efadd8 100644 --- a/models/Image.go +++ b/models/Image.go @@ -30,9 +30,13 @@ type Image struct { NumInferenceSteps int `json:"num_inference_steps"` // Number of inference steps (minimum: 1; maximum: 500) GuidanceScale float32 `json:"guidance_scale"` // Scale for classifier-free guidance (minimum: 1; maximum: 20) Scheduler string `json:"scheduler"` // (DDIM|K_EULER|DPMSolverMultistep|K_EULER_ANCESTRAL|PNDM|KLMS) - Seed int `json:"seed"` // Random seed (minimum: 0; maximum: 2147483647) - FromImage string `json:"from_image"` // Image to start from - UserID int `json:"user_id"` + Seed int `json:"seed"` // 随机种子(minimum: 0; maximum: 2147483647) + FromImage int `json:"from_image"` // 来源图片(如果是从图片生成的, 则记录来源图片的ID) + Task string `json:"task"` // 任务编号(uuid) + Status string `json:"status"` // 任务状态(queued|running|finished|failed) + Progress int `json:"progress"` // 任务进度(0-100) + Public bool `json:"public"` // 是否公开 + UserID int `json:"user_id"` // 用户ID CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` } diff --git a/routers/images.go b/routers/images.go index 341dcbd..111b6f4 100644 --- a/routers/images.go +++ b/routers/images.go @@ -17,6 +17,7 @@ import ( "net/http" "os" + "github.com/google/uuid" "github.com/gorilla/mux" ) @@ -40,6 +41,61 @@ func ImagesGet(w http.ResponseWriter, r *http.Request) { func ImagesPost(w http.ResponseWriter, r *http.Request) { models.AccountRead(w, r, func(account *models.Account) { + + // 通过模型推理生成图像, 为图像标记任务批次 + if r.Header.Get("Content-Type") == "application/json" || r.Header.Get("Content-Type") == "application/json; charset=utf-8" { + + // 接收模板参数 + template := &struct { + FromImage int `json:"from_image"` // 来源图片(图生图时使用) + Prompt string `json:"prompt"` // 提示词 + NegativePrompt string `json:"negative_prompt"` // 负面提示词 + NumInferenceSteps int `json:"num_inference_steps"` // 推理步数 + GuidanceScale float32 `json:"guidance_scale"` // 引导比例 + Scheduler string `json:"scheduler"` // 调度器 + Seed int `json:"seed"` // 随机种子(单张图生成时使用) + Number int `json:"number"` // 生成数量 + }{} + body, err := ioutil.ReadAll(r.Body) + if err != nil { + log.Println(err) + return + } + defer r.Body.Close() + if err = json.Unmarshal(body, &template); err != nil { + log.Println(err) + return + } + + // 直接创建一组图片 + task := uuid.New().String() + var image_list []models.Image + for i := 0; i < template.Number; i++ { + var image models.Image + image.UserID = account.ID + image.Task = task + image.Status = "queued" + image.FromImage = template.FromImage + image.Prompt = template.Prompt + image.NegativePrompt = template.NegativePrompt + image.NumInferenceSteps = template.NumInferenceSteps + image.GuidanceScale = template.GuidanceScale + image.Scheduler = template.Scheduler + image.Seed = template.Seed + image_list = append(image_list, image) + } + + // 存储图片信息到数据库 + if err := configs.ORMDB().Create(&image_list).Error; err != nil { + log.Println(err) + return + } + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + json.NewEncoder(w).Encode(image_list) + return + } + // 接收上傳的圖片文件, 僅限一張 file, file_header, err := r.FormFile("file") if err != nil {