Specify the number of executions; remove the grid image

2023-02-18 21:13:41 +08:00
parent 607d958a9d
commit a751623e4d

@@ -355,11 +355,11 @@ def main_dev(opt):
                 for _ in range(3):
                     x_samples_ddim = model.decode_first_stage(samples_ddim)
     precision_scope = autocast if opt.precision=="autocast" or opt.bf16 else nullcontext
-    with torch.no_grad(), \
-        precision_scope(opt.device), \
-        model.ema_scope():
+    with torch.no_grad(), precision_scope(opt.device), model.ema_scope():
             all_samples = list()
-            for n in trange(opt.n_iter, desc="Sampling"):
+            # run the specified number of times
+            for n in trange(item['number'], desc="Sampling"):
+                print("Sampling:", n)
                 for prompts in tqdm(data, desc="data"):
                     uc = None
                     if opt.scale != 1.0:
@@ -387,16 +387,45 @@ def main_dev(opt):
                         base_count += 1
                         sample_count += 1
                     all_samples.append(x_samples)
-            # additionally, save as grid
-            grid = torch.stack(all_samples, 0)
-            grid = rearrange(grid, 'n b c h w -> (n b) c h w')
-            grid = make_grid(grid, nrow=n_rows)
-            # to image
-            grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
-            grid = Image.fromarray(grid.astype(np.uint8))
-            grid = put_watermark(grid, wm_encoder)
-            grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
-            grid_count += 1
+            print("Sample count:", sample_count)
+            #for n in trange(opt.n_iter, desc="Sampling"):
+            #    for prompts in tqdm(data, desc="data"):
+            #        uc = None
+            #        if opt.scale != 1.0:
+            #            uc = model.get_learned_conditioning(batch_size * [""])
+            #        if isinstance(prompts, tuple):
+            #            prompts = list(prompts)
+            #        c = model.get_learned_conditioning(prompts)
+            #        shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
+            #        samples, _ = sampler.sample(S=opt.steps,
+            #                                    conditioning=c,
+            #                                    batch_size=opt.n_samples,
+            #                                    shape=shape,
+            #                                    verbose=False,
+            #                                    unconditional_guidance_scale=opt.scale,
+            #                                    unconditional_conditioning=uc,
+            #                                    eta=opt.ddim_eta,
+            #                                    x_T=start_code)
+            #        x_samples = model.decode_first_stage(samples)
+            #        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
+            #        for x_sample in x_samples:
+            #            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+            #            img = Image.fromarray(x_sample.astype(np.uint8))
+            #            img = put_watermark(img, wm_encoder)
+            #            img.save(os.path.join(sample_path, f"{base_count:05}.png"))
+            #            base_count += 1
+            #            sample_count += 1
+            #        all_samples.append(x_samples)
+            ## additionally, save as grid
+            #grid = torch.stack(all_samples, 0)
+            #grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+            #grid = make_grid(grid, nrow=n_rows)
+            ## to image
+            #grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
+            #grid = Image.fromarray(grid.astype(np.uint8))
+            #grid = put_watermark(grid, wm_encoder)
+            #grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
+            #grid_count += 1
     print(f"Your samples are ready and waiting for you here: \n{outpath} \n", f" \nEnjoy.")
     # mark the task status as done
     update_task_status(task=item, status='done', progress=1)
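
In effect, the sampling loop now runs item['number'] times per queued task instead of opt.n_iter times, and the grid montage is no longer assembled or saved. Below is a minimal sketch of the resulting control flow, assuming item is the task dict pulled off the queue and using a hypothetical sample_once helper as a stand-in for the per-prompt sampling-and-saving code shown in the diff:

from tqdm import tqdm, trange

def run_task(item, data, sample_once):
    # Run the sampler the number of times requested by the task, not opt.n_iter.
    all_samples = []
    for n in trange(item['number'], desc="Sampling"):
        print("Sampling:", n)
        for prompts in tqdm(data, desc="data"):
            # sample_once is a hypothetical stand-in: it samples one batch,
            # writes the individual PNGs, and returns the decoded tensors.
            all_samples.append(sample_once(prompts))
    # No torch.stack / make_grid / grid.save step any more; only the
    # per-sample images written inside sample_once remain on disk.
    print("Sample count:", len(all_samples))
    return all_samples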