This commit is contained in:
2023-02-18 21:50:01 +08:00
parent d64260d5b0
commit f92d9e9de6

View File

@@ -240,7 +240,7 @@ def main_dev(opt):
if 'prompt' in item:
opt.prompt = item['prompt'] # 描述
if 'n_samples' in item:
opt.n_samples = item['n_samples'] # 列数
opt.n_samples = item['number'] # 列数
if 'n_rows' in item:
opt.n_rows = item['n_rows'] # 行数
if 'scale' in item:
@@ -345,10 +345,8 @@ def main_dev(opt):
if isinstance(prompts, tuple):
prompts = list(prompts)
with torch.no_grad(), additional_context:
#for _ in range(3):
# c = model.get_learned_conditioning(prompts)
for _ in range(3):
c = model.get_learned_conditioning(prompts)
print(c)
samples_ddim, _ = sampler.sample(S=5,
conditioning=c,
batch_size=batch_size,
@@ -364,8 +362,8 @@ def main_dev(opt):
precision_scope = autocast if opt.precision == "autocast" or opt.bf16 else nullcontext
with torch.no_grad(), precision_scope(opt.device), model.ema_scope():
#all_samples = list()
# 执行指定的次数
for n in trange(item['number'], desc="Sampling"):
# 执行指定的任务组数 (row)(item['number'])
for n in trange(1, desc="Sampling"):
print("Sampling:", n)
for prompts in tqdm(data, desc="data"):
uc = None