diff --git a/.gitignore b/.gitignore index 438cb08..0041888 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +# javascript node_modules *.log* .nuxt @@ -6,3 +7,6 @@ node_modules .output .env dist + +# python +venv diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..fd0532b --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +certifi==2022.12.7 +charset-normalizer==3.0.1 +idna==3.4 +requests==2.28.2 +urllib3==1.26.14 diff --git a/server.py b/server.py index 9d955e3..8d9269b 100644 --- a/server.py +++ b/server.py @@ -210,6 +210,40 @@ def put_watermark(img, wm_encoder=None): img = Image.fromarray(img[:, :, ::-1]) return img +import time +import requests + +# 获取model, 如果和之前的model不一样,重新加载 +def get_model(model_name): + global model + global config + global device + if model_name != model_name: + config = OmegaConf.load(f"{opt.config}") + device = torch.device("cuda") if opt.device == "cuda" else torch.device("cpu") + model = load_model_from_config(config, f"{opt.ckpt}", device) + return model + +# 使用指定的模型和配置文件进行推理一组参数 +def drawing(model_name): + model = get_model(model_name) + if opt.plms: + sampler = PLMSSampler(model, device=device) + elif opt.dpm: + sampler = DPMSolverSampler(model, device=device) + else: + sampler = DDIMSampler(model, device=device) + + +def main_dev(opt): + while True: + time.sleep(1) # 延时1s执行, 避免cpu占用过高 + # 从局域网中获取一组参数 + request = requests.get("http://localhost:3000/api/drawing") + if request.status_code == 200: + data = request.json() + print("data: ", data) + #drawing("model_name") def main(opt): seed_everything(opt.seed) @@ -385,4 +419,4 @@ def main(opt): if __name__ == "__main__": opt = parse_args() - main(opt) + main_dev(opt)