DEBUG
This commit is contained in:
12
server.py
12
server.py
@@ -240,7 +240,7 @@ def main_dev(opt):
|
||||
if 'prompt' in item:
|
||||
opt.prompt = item['prompt'] # 描述
|
||||
if 'n_samples' in item:
|
||||
opt.n_samples = item['n_samples'] # 列数
|
||||
opt.n_samples = item['number'] # 列数
|
||||
if 'n_rows' in item:
|
||||
opt.n_rows = item['n_rows'] # 行数
|
||||
if 'scale' in item:
|
||||
@@ -345,10 +345,8 @@ def main_dev(opt):
|
||||
if isinstance(prompts, tuple):
|
||||
prompts = list(prompts)
|
||||
with torch.no_grad(), additional_context:
|
||||
#for _ in range(3):
|
||||
# c = model.get_learned_conditioning(prompts)
|
||||
c = model.get_learned_conditioning(prompts)
|
||||
print(c)
|
||||
for _ in range(3):
|
||||
c = model.get_learned_conditioning(prompts)
|
||||
samples_ddim, _ = sampler.sample(S=5,
|
||||
conditioning=c,
|
||||
batch_size=batch_size,
|
||||
@@ -364,8 +362,8 @@ def main_dev(opt):
|
||||
precision_scope = autocast if opt.precision == "autocast" or opt.bf16 else nullcontext
|
||||
with torch.no_grad(), precision_scope(opt.device), model.ema_scope():
|
||||
#all_samples = list()
|
||||
# 执行指定的次数
|
||||
for n in trange(item['number'], desc="Sampling"):
|
||||
# 执行指定的任务组数 (row)(item['number'])
|
||||
for n in trange(1, desc="Sampling"):
|
||||
print("Sampling:", n)
|
||||
for prompts in tqdm(data, desc="data"):
|
||||
uc = None
|
||||
|
Reference in New Issue
Block a user