diff --git a/server.py b/server.py index 45cc34f..eda9cb1 100644 --- a/server.py +++ b/server.py @@ -229,7 +229,8 @@ def get_tasks(tasks:list=[]): tasks = requests.get("http://localhost:3000/api/drawing?status=waiting").json() if len(tasks) == 0: time.sleep(2) except: - print("get tasks error") + # 打印当前时间 + print("get tasks error", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) time.sleep(2) return tasks @@ -241,10 +242,11 @@ def main_dev(opt): while True: for task in get_tasks(): # 遍历 tasks 返回 dict print('task:', task) # 打印任务 - update_task_status(task, "running", 0) # 更新任务状态为运行中 # 如果模型不同,重新加载模型(注意释放内存) if task['ckpt'] != model_name: + # 修改状态为加载模型 + update_task_status(task, "loading", 0) # 获取环境配置 model_name = task['ckpt'] opt.config = f'/data/{model_name}.yaml' @@ -257,6 +259,9 @@ def main_dev(opt): print(f"加载模型到显存: {model_name}..") model = load_model_from_config(config, f"{opt.ckpt}", device) print(f"加载到显存完成: {model_name}") + + # 更新任务状态为运行中 + update_task_status(task, "running", 0) # 使用指定的模型和配置文件进行推理一组参数 if opt.plms: @@ -300,13 +305,16 @@ def main_dev(opt): start_code = None if opt.fixed_code: start_code = torch.randn([task['number'], opt.C, opt.H // opt.f, opt.W // opt.f], device=device) + + # 更新进度 + update_task_status(task, "running", 0.1) + # 生成图片 precision_scope = autocast if opt.precision == "autocast" or opt.bf16 else nullcontext with torch.no_grad(), precision_scope(opt.device), model.ema_scope(): images = [] # 执行指定的任务批次 (row)(task['number']) for n in trange(1, desc="Sampling"): - print("Sampling:", data) for prompts in tqdm(data, desc="data"): uc = None if opt.scale != 1.0: @@ -327,7 +335,6 @@ def main_dev(opt): x_samples = model.decode_first_stage(samples) x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) for x_sample in x_samples: - print("Sample count:", sample_count) imge_path = os.path.join(sample_path, f"{base_count:05}.png") imge_path = os.path.abspath(imge_path) # 转换为绝对路径 x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') @@ -337,7 +344,6 @@ def main_dev(opt): base_count += 1 sample_count += 1 images.append(imge_path) - print("Sample count:", sample_count) update_task_status(task=task, status='done', progress=1, data=images) # 修改任务状态为完成 print("批次任务结束..") #break diff --git a/server/api/img/[id].ts b/server/api/img/[id].ts index 450c2f6..15fd9a3 100644 --- a/server/api/img/[id].ts +++ b/server/api/img/[id].ts @@ -8,7 +8,7 @@ export default defineEventHandler(async event => { if (!fs.existsSync('outputs')) { // 打印当前执行目录 console.log('cwd:', process.cwd()) - path = `../../outputs/txt2img-samples/samples/${event.context.params.id}` + path = `../outputs/txt2img-samples/samples/${event.context.params.id}` } console.log('path:', path) return fs.readFileSync(path)