python
This commit is contained in:
36
server.py
36
server.py
@@ -210,6 +210,40 @@ def put_watermark(img, wm_encoder=None):
|
||||
img = Image.fromarray(img[:, :, ::-1])
|
||||
return img
|
||||
|
||||
import time
|
||||
import requests
|
||||
|
||||
# 获取model, 如果和之前的model不一样,重新加载
|
||||
def get_model(model_name):
|
||||
global model
|
||||
global config
|
||||
global device
|
||||
if model_name != model_name:
|
||||
config = OmegaConf.load(f"{opt.config}")
|
||||
device = torch.device("cuda") if opt.device == "cuda" else torch.device("cpu")
|
||||
model = load_model_from_config(config, f"{opt.ckpt}", device)
|
||||
return model
|
||||
|
||||
# 使用指定的模型和配置文件进行推理一组参数
|
||||
def drawing(model_name):
|
||||
model = get_model(model_name)
|
||||
if opt.plms:
|
||||
sampler = PLMSSampler(model, device=device)
|
||||
elif opt.dpm:
|
||||
sampler = DPMSolverSampler(model, device=device)
|
||||
else:
|
||||
sampler = DDIMSampler(model, device=device)
|
||||
|
||||
|
||||
def main_dev(opt):
|
||||
while True:
|
||||
time.sleep(1) # 延时1s执行, 避免cpu占用过高
|
||||
# 从局域网中获取一组参数
|
||||
request = requests.get("http://localhost:3000/api/drawing")
|
||||
if request.status_code == 200:
|
||||
data = request.json()
|
||||
print("data: ", data)
|
||||
#drawing("model_name")
|
||||
|
||||
def main(opt):
|
||||
seed_everything(opt.seed)
|
||||
@@ -385,4 +419,4 @@ def main(opt):
|
||||
|
||||
if __name__ == "__main__":
|
||||
opt = parse_args()
|
||||
main(opt)
|
||||
main_dev(opt)
|
||||
|
Reference in New Issue
Block a user