进度条

This commit is contained in:
2023-02-19 05:48:35 +08:00
parent 4d6ca94eff
commit f714c0db32

View File

@@ -246,7 +246,7 @@ def main_dev(opt):
# 如果模型不同,重新加载模型(注意释放内存) # 如果模型不同,重新加载模型(注意释放内存)
if task['ckpt'] != model_name: if task['ckpt'] != model_name:
# 修改状态为加载模型 # 修改状态为加载模型
update_task_status(task, "loading", 0) update_task_status(task, "init", 0)
# 获取环境配置 # 获取环境配置
model_name = task['ckpt'] model_name = task['ckpt']
opt.config = f'/data/{model_name}.yaml' opt.config = f'/data/{model_name}.yaml'
@@ -323,6 +323,7 @@ def main_dev(opt):
prompts = list(prompts) prompts = list(prompts)
c = model.get_learned_conditioning(prompts) c = model.get_learned_conditioning(prompts)
shape = [opt.C, opt.H // opt.f, opt.W // opt.f] shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
update_task_status(task=task, status='diffusing', progress=0.5) # 修改任务状态
samples, _ = sampler.sample(S=opt.steps, samples, _ = sampler.sample(S=opt.steps,
conditioning=c, conditioning=c,
batch_size=task['number'], batch_size=task['number'],
@@ -332,6 +333,7 @@ def main_dev(opt):
unconditional_conditioning=uc, unconditional_conditioning=uc,
eta=opt.ddim_eta, eta=opt.ddim_eta,
x_T=start_code) x_T=start_code)
update_task_status(task=task, status='build', progress=0.8) # 修改任务状态
x_samples = model.decode_first_stage(samples) x_samples = model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
for x_sample in x_samples: for x_sample in x_samples: