进度条
This commit is contained in:
@@ -246,7 +246,7 @@ def main_dev(opt):
|
|||||||
# 如果模型不同,重新加载模型(注意释放内存)
|
# 如果模型不同,重新加载模型(注意释放内存)
|
||||||
if task['ckpt'] != model_name:
|
if task['ckpt'] != model_name:
|
||||||
# 修改状态为加载模型
|
# 修改状态为加载模型
|
||||||
update_task_status(task, "loading", 0)
|
update_task_status(task, "init", 0)
|
||||||
# 获取环境配置
|
# 获取环境配置
|
||||||
model_name = task['ckpt']
|
model_name = task['ckpt']
|
||||||
opt.config = f'/data/{model_name}.yaml'
|
opt.config = f'/data/{model_name}.yaml'
|
||||||
@@ -323,6 +323,7 @@ def main_dev(opt):
|
|||||||
prompts = list(prompts)
|
prompts = list(prompts)
|
||||||
c = model.get_learned_conditioning(prompts)
|
c = model.get_learned_conditioning(prompts)
|
||||||
shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
|
shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
|
||||||
|
update_task_status(task=task, status='diffusing', progress=0.5) # 修改任务状态
|
||||||
samples, _ = sampler.sample(S=opt.steps,
|
samples, _ = sampler.sample(S=opt.steps,
|
||||||
conditioning=c,
|
conditioning=c,
|
||||||
batch_size=task['number'],
|
batch_size=task['number'],
|
||||||
@@ -332,6 +333,7 @@ def main_dev(opt):
|
|||||||
unconditional_conditioning=uc,
|
unconditional_conditioning=uc,
|
||||||
eta=opt.ddim_eta,
|
eta=opt.ddim_eta,
|
||||||
x_T=start_code)
|
x_T=start_code)
|
||||||
|
update_task_status(task=task, status='build', progress=0.8) # 修改任务状态
|
||||||
x_samples = model.decode_first_stage(samples)
|
x_samples = model.decode_first_stage(samples)
|
||||||
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
for x_sample in x_samples:
|
for x_sample in x_samples:
|
||||||
|
Reference in New Issue
Block a user