This commit is contained in:
2023-02-14 03:12:11 +08:00
parent 8e581e3a56
commit e2e229398f
3 changed files with 44 additions and 1 deletions

4
.gitignore vendored
View File

@@ -1,3 +1,4 @@
# javascript
node_modules node_modules
*.log* *.log*
.nuxt .nuxt
@@ -6,3 +7,6 @@ node_modules
.output .output
.env .env
dist dist
# python
venv

5
requirements.txt Normal file
View File

@@ -0,0 +1,5 @@
certifi==2022.12.7
charset-normalizer==3.0.1
idna==3.4
requests==2.28.2
urllib3==1.26.14

View File

@@ -210,6 +210,40 @@ def put_watermark(img, wm_encoder=None):
img = Image.fromarray(img[:, :, ::-1]) img = Image.fromarray(img[:, :, ::-1])
return img return img
import time
import requests
# 获取model, 如果和之前的model不一样重新加载
def get_model(model_name):
global model
global config
global device
if model_name != model_name:
config = OmegaConf.load(f"{opt.config}")
device = torch.device("cuda") if opt.device == "cuda" else torch.device("cpu")
model = load_model_from_config(config, f"{opt.ckpt}", device)
return model
# 使用指定的模型和配置文件进行推理一组参数
def drawing(model_name):
model = get_model(model_name)
if opt.plms:
sampler = PLMSSampler(model, device=device)
elif opt.dpm:
sampler = DPMSolverSampler(model, device=device)
else:
sampler = DDIMSampler(model, device=device)
def main_dev(opt):
while True:
time.sleep(1) # 延时1s执行, 避免cpu占用过高
# 从局域网中获取一组参数
request = requests.get("http://localhost:3000/api/drawing")
if request.status_code == 200:
data = request.json()
print("data: ", data)
#drawing("model_name")
def main(opt): def main(opt):
seed_everything(opt.seed) seed_everything(opt.seed)
@@ -385,4 +419,4 @@ def main(opt):
if __name__ == "__main__": if __name__ == "__main__":
opt = parse_args() opt = parse_args()
main(opt) main_dev(opt)