This commit is contained in:
2023-02-18 21:50:01 +08:00
parent d64260d5b0
commit f92d9e9de6

View File

@@ -240,7 +240,7 @@ def main_dev(opt):
if 'prompt' in item: if 'prompt' in item:
opt.prompt = item['prompt'] # 描述 opt.prompt = item['prompt'] # 描述
if 'n_samples' in item: if 'n_samples' in item:
opt.n_samples = item['n_samples'] # 列数 opt.n_samples = item['number'] # 列数
if 'n_rows' in item: if 'n_rows' in item:
opt.n_rows = item['n_rows'] # 行数 opt.n_rows = item['n_rows'] # 行数
if 'scale' in item: if 'scale' in item:
@@ -345,10 +345,8 @@ def main_dev(opt):
if isinstance(prompts, tuple): if isinstance(prompts, tuple):
prompts = list(prompts) prompts = list(prompts)
with torch.no_grad(), additional_context: with torch.no_grad(), additional_context:
#for _ in range(3): for _ in range(3):
# c = model.get_learned_conditioning(prompts) c = model.get_learned_conditioning(prompts)
c = model.get_learned_conditioning(prompts)
print(c)
samples_ddim, _ = sampler.sample(S=5, samples_ddim, _ = sampler.sample(S=5,
conditioning=c, conditioning=c,
batch_size=batch_size, batch_size=batch_size,
@@ -364,8 +362,8 @@ def main_dev(opt):
precision_scope = autocast if opt.precision == "autocast" or opt.bf16 else nullcontext precision_scope = autocast if opt.precision == "autocast" or opt.bf16 else nullcontext
with torch.no_grad(), precision_scope(opt.device), model.ema_scope(): with torch.no_grad(), precision_scope(opt.device), model.ema_scope():
#all_samples = list() #all_samples = list()
# 执行指定的次数 # 执行指定的任务组数 (row)(item['number'])
for n in trange(item['number'], desc="Sampling"): for n in trange(1, desc="Sampling"):
print("Sampling:", n) print("Sampling:", n)
for prompts in tqdm(data, desc="data"): for prompts in tqdm(data, desc="data"):
uc = None uc = None