输入检查

This commit is contained in:
2023-06-20 19:01:18 +08:00
parent b656065b43
commit 0504a4643c

View File

@@ -8,6 +8,7 @@ import (
_ "image/gif" _ "image/gif"
_ "image/jpeg" _ "image/jpeg"
_ "image/png" _ "image/png"
"regexp"
"io/ioutil" "io/ioutil"
"log" "log"
@@ -82,9 +83,7 @@ func ImagesPost(w http.ResponseWriter, r *http.Request) {
models.AccountRead(w, r, func(account *models.Account) { models.AccountRead(w, r, func(account *models.Account) {
// 通过模型推理生成图像, 为图像标记任务批次 // 通过模型推理生成图像, 为图像标记任务批次
if r.Header.Get("Content-Type") == "application/json" || r.Header.Get("Content-Type") == "application/json; charset=utf-8" { if match, _ := regexp.MatchString("application/json", r.Header.Get("Content-Type")); match {
// 接收模板参数
template := &struct { template := &struct {
FromImage int `json:"from_image"` // 来源图片(图生图时使用) FromImage int `json:"from_image"` // 来源图片(图生图时使用)
Prompt string `json:"prompt"` // 提示词 Prompt string `json:"prompt"` // 提示词
@@ -92,7 +91,7 @@ func ImagesPost(w http.ResponseWriter, r *http.Request) {
NumInferenceSteps int `json:"num_inference_steps"` // 推理步数 NumInferenceSteps int `json:"num_inference_steps"` // 推理步数
GuidanceScale float32 `json:"guidance_scale"` // 引导比例 GuidanceScale float32 `json:"guidance_scale"` // 引导比例
Scheduler string `json:"scheduler"` // 调度器 Scheduler string `json:"scheduler"` // 调度器
Seed int `json:"seed"` // 随机种子(单张图生成时使用) Seed string `json:"seed"` // 随机种子(单张图生成时使用)
Number int `json:"number"` // 生成数量 Number int `json:"number"` // 生成数量
}{} }{}
body, err := ioutil.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
@@ -106,6 +105,20 @@ func ImagesPost(w http.ResponseWriter, r *http.Request) {
return return
} }
// 输入检查
if template.Number <= 0 {
template.Number = 1
}
if template.NumInferenceSteps <= 0 {
template.NumInferenceSteps = 20
}
if template.GuidanceScale <= 0 {
template.GuidanceScale = 1
}
if template.GuidanceScale > 20 {
template.GuidanceScale = 20
}
// TODO: 创建任务获得任务编号, 多张图时期望可以流式推理 // TODO: 创建任务获得任务编号, 多张图时期望可以流式推理
task := uuid.New().String() task := uuid.New().String()
@@ -134,6 +147,7 @@ func ImagesPost(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Header().Set("Content-Type", "application/json; charset=utf-8")
json.NewEncoder(w).Encode(image_list) json.NewEncoder(w).Encode(image_list)
//w.Write(utils.ToJSON({"task": task, "list": image_list}))
return return
} }