提示词引导系数 (CFG Scale)

This commit is contained in:
2023-06-23 02:06:05 +08:00
parent 546ff57d6a
commit 8e874f7d0f
3 changed files with 21 additions and 21 deletions

View File

@@ -28,7 +28,7 @@ type Image struct {
Prompt string `json:"prompt"` // 提示词 Prompt string `json:"prompt"` // 提示词
NegativePrompt string `json:"negative_prompt"` // 负向提示 NegativePrompt string `json:"negative_prompt"` // 负向提示
NumInferenceSteps int `json:"num_inference_steps"` // 推理步数(minimum: 1; maximum: 500) NumInferenceSteps int `json:"num_inference_steps"` // 推理步数(minimum: 1; maximum: 500)
GuidanceScale float32 `json:"guidance_scale"` // 引导比例(minimum: 1; maximum: 20) CfgScale int `json:"cfg_scale"` // 引导比例(minimum: 1; maximum: 20)
Scheduler string `json:"scheduler"` // 调度器(DDIM|K_EULER|DPMSolverMultistep|K_EULER_ANCESTRAL|PNDM|KLMS) Scheduler string `json:"scheduler"` // 调度器(DDIM|K_EULER|DPMSolverMultistep|K_EULER_ANCESTRAL|PNDM|KLMS)
Seed string `json:"seed"` // 随机种子(minimum: 0; maximum: 2147483647) Seed string `json:"seed"` // 随机种子(minimum: 0; maximum: 2147483647)
FromImage int `json:"from_image"` // 来源图片(如果是从图片生成的, 则记录来源图片的ID) FromImage int `json:"from_image"` // 来源图片(如果是从图片生成的, 则记录来源图片的ID)

View File

@@ -108,9 +108,9 @@ func (model *Model) Inference(image_list []Image, callback func(Image)) {
//SeedResizeFromW int `json:"seed_resize_from_w"` //SeedResizeFromW int `json:"seed_resize_from_w"`
//SamplerName string `json:"sampler_name"` //SamplerName string `json:"sampler_name"`
//BatchSize int `json:"batch_size"` //BatchSize int `json:"batch_size"`
NIter int `json:"n_iter"` NIter int `json:"n_iter"`
Steps int `json:"steps"` Steps int `json:"steps"`
//CfgScale int `json:"cfg_scale"` CfgScale int `json:"cfg_scale"`
//Width int `json:"width"` //Width int `json:"width"`
//Height int `json:"height"` //Height int `json:"height"`
//RestoreFaces bool `json:"restore_faces"` //RestoreFaces bool `json:"restore_faces"`
@@ -154,9 +154,9 @@ func (model *Model) Inference(image_list []Image, callback func(Image)) {
//SeedResizeFromW: -1, //SeedResizeFromW: -1,
//SamplerName: "beamsearch", //SamplerName: "beamsearch",
//BatchSize: 1, //BatchSize: 1,
NIter: len(image_list), NIter: len(image_list), // 1~100
Steps: 50, Steps: 50, // 1~150
//CfgScale: 7, CfgScale: image_list[0].CfgScale,
//Width: 512, //Width: 512,
//Height: 512, //Height: 512,
//RestoreFaces: false, //RestoreFaces: false,

View File

@@ -103,15 +103,15 @@ func ImagesPost(w http.ResponseWriter, r *http.Request) {
// 通过模型推理生成图像, 为图像标记任务批次 // 通过模型推理生成图像, 为图像标记任务批次
if match, _ := regexp.MatchString("application/json", r.Header.Get("Content-Type")); match { if match, _ := regexp.MatchString("application/json", r.Header.Get("Content-Type")); match {
template := &struct { template := &struct {
FromImage int `json:"from_image"` // 来源图片(图生图时使用) FromImage int `json:"from_image"` // 来源图片(图生图时使用)
Prompt string `json:"prompt"` // 提示词 Prompt string `json:"prompt"` // 提示词
NegativePrompt string `json:"negative_prompt"` // 负面提示词 NegativePrompt string `json:"negative_prompt"` // 负面提示词
Steps int `json:"steps"` // 推理步数 Steps int `json:"steps"` // 推理步数
GuidanceScale float32 `json:"guidance_scale"` // 引导比例 CfgScale int `json:"cfg_scale"` // 引导比例
Scheduler string `json:"scheduler"` // 调度器 Scheduler string `json:"scheduler"` // 调度器
Seed string `json:"seed"` // 随机种子(单张图生成时使用) Seed string `json:"seed"` // 随机种子(单张图生成时使用)
NIter int `json:"n_iter"` // 生成数量 NIter int `json:"n_iter"` // 生成数量
ModelID int `json:"model_id"` // 模型ID ModelID int `json:"model_id"` // 模型ID
}{} }{}
body, err := ioutil.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
@@ -131,11 +131,11 @@ func ImagesPost(w http.ResponseWriter, r *http.Request) {
if template.Steps <= 0 { if template.Steps <= 0 {
template.Steps = 50 template.Steps = 50
} }
if template.GuidanceScale <= 0 { if template.CfgScale <= 0 {
template.GuidanceScale = 1 template.CfgScale = 1
} }
if template.GuidanceScale > 20 { if template.CfgScale > 20 {
template.GuidanceScale = 20 template.CfgScale = 20
} }
if template.Scheduler == "" { if template.Scheduler == "" {
template.Scheduler = "DDIM" template.Scheduler = "DDIM"
@@ -166,7 +166,7 @@ func ImagesPost(w http.ResponseWriter, r *http.Request) {
image.Prompt = template.Prompt image.Prompt = template.Prompt
image.NegativePrompt = template.NegativePrompt image.NegativePrompt = template.NegativePrompt
image.NumInferenceSteps = template.Steps image.NumInferenceSteps = template.Steps
image.GuidanceScale = template.GuidanceScale image.CfgScale = template.CfgScale
image.Scheduler = template.Scheduler image.Scheduler = template.Scheduler
image.Seed = template.Seed image.Seed = template.Seed
image_list = append(image_list, image) image_list = append(image_list, image)