This commit is contained in:
2023-02-19 05:06:23 +08:00
parent eb0c491577
commit 4d6ca94eff
2 changed files with 12 additions and 6 deletions

View File

@@ -229,7 +229,8 @@ def get_tasks(tasks:list=[]):
tasks = requests.get("http://localhost:3000/api/drawing?status=waiting").json()
if len(tasks) == 0: time.sleep(2)
except:
print("get tasks error")
# 打印当前时间
print("get tasks error", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
time.sleep(2)
return tasks
@@ -241,10 +242,11 @@ def main_dev(opt):
while True:
for task in get_tasks(): # 遍历 tasks 返回 dict
print('task:', task) # 打印任务
update_task_status(task, "running", 0) # 更新任务状态为运行中
# 如果模型不同,重新加载模型(注意释放内存)
if task['ckpt'] != model_name:
# 修改状态为加载模型
update_task_status(task, "loading", 0)
# 获取环境配置
model_name = task['ckpt']
opt.config = f'/data/{model_name}.yaml'
@@ -257,6 +259,9 @@ def main_dev(opt):
print(f"加载模型到显存: {model_name}..")
model = load_model_from_config(config, f"{opt.ckpt}", device)
print(f"加载到显存完成: {model_name}")
# 更新任务状态为运行中
update_task_status(task, "running", 0)
# 使用指定的模型和配置文件进行推理一组参数
if opt.plms:
@@ -300,13 +305,16 @@ def main_dev(opt):
start_code = None
if opt.fixed_code:
start_code = torch.randn([task['number'], opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
# 更新进度
update_task_status(task, "running", 0.1)
# 生成图片
precision_scope = autocast if opt.precision == "autocast" or opt.bf16 else nullcontext
with torch.no_grad(), precision_scope(opt.device), model.ema_scope():
images = []
# 执行指定的任务批次 (row)(task['number'])
for n in trange(1, desc="Sampling"):
print("Sampling:", data)
for prompts in tqdm(data, desc="data"):
uc = None
if opt.scale != 1.0:
@@ -327,7 +335,6 @@ def main_dev(opt):
x_samples = model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
for x_sample in x_samples:
print("Sample count:", sample_count)
imge_path = os.path.join(sample_path, f"{base_count:05}.png")
imge_path = os.path.abspath(imge_path) # 转换为绝对路径
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
@@ -337,7 +344,6 @@ def main_dev(opt):
base_count += 1
sample_count += 1
images.append(imge_path)
print("Sample count:", sample_count)
update_task_status(task=task, status='done', progress=1, data=images) # 修改任务状态为完成
print("批次任务结束..")
#break

View File

@@ -8,7 +8,7 @@ export default defineEventHandler(async event => {
if (!fs.existsSync('outputs')) {
// 打印当前执行目录
console.log('cwd:', process.cwd())
path = `../../outputs/txt2img-samples/samples/${event.context.params.id}`
path = `../outputs/txt2img-samples/samples/${event.context.params.id}`
}
console.log('path:', path)
return fs.readFileSync(path)