diff --git a/server.py b/server.py index a1eda1d..2efe6ce 100644 --- a/server.py +++ b/server.py @@ -240,7 +240,7 @@ def main_dev(opt): if 'prompt' in item: opt.prompt = item['prompt'] # 描述 if 'n_samples' in item: - opt.n_samples = item['n_samples'] # 列数 + opt.n_samples = item['number'] # 列数 if 'n_rows' in item: opt.n_rows = item['n_rows'] # 行数 if 'scale' in item: @@ -345,10 +345,8 @@ def main_dev(opt): if isinstance(prompts, tuple): prompts = list(prompts) with torch.no_grad(), additional_context: - #for _ in range(3): - # c = model.get_learned_conditioning(prompts) - c = model.get_learned_conditioning(prompts) - print(c) + for _ in range(3): + c = model.get_learned_conditioning(prompts) samples_ddim, _ = sampler.sample(S=5, conditioning=c, batch_size=batch_size, @@ -364,8 +362,8 @@ def main_dev(opt): precision_scope = autocast if opt.precision == "autocast" or opt.bf16 else nullcontext with torch.no_grad(), precision_scope(opt.device), model.ema_scope(): #all_samples = list() - # 执行指定的次数 - for n in trange(item['number'], desc="Sampling"): + # 执行指定的任务组数 (row)(item['number']) + for n in trange(1, desc="Sampling"): print("Sampling:", n) for prompts in tqdm(data, desc="data"): uc = None