diff --git a/server.py b/server.py index 2e047d3..6118fac 100644 --- a/server.py +++ b/server.py @@ -264,9 +264,15 @@ def main_dev(opt): update_task_status(task, "init", 0) # 获取环境配置 model_name = task['ckpt'] - opt.config = f'/data/{model_name}.yaml' - opt.ckpt = f'/data/{model_name}.ckpt' + opt.config = f'/data/ckpt/{model_name}.yaml' + opt.ckpt = f'/data/ckpt/{model_name}.ckpt' opt.device = 'cuda' + + # 检查yaml文件是否存在 + if not os.path.exists(opt.config): + print(f"yaml文件不存在: {opt.config}, 将使用默认配置") + opt.config = "/data/stablediffusion/configs/stable-diffusion/v2-inference.yaml" + print(f"config: {opt.config}", f"ckpt: {opt.ckpt}", f"device: {opt.device}") config = OmegaConf.load(f"{opt.config}") device = torch.device("cuda") if opt.device == "cuda" else torch.device("cpu")