python
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -1,3 +1,4 @@
|
|||||||
|
# javascript
|
||||||
node_modules
|
node_modules
|
||||||
*.log*
|
*.log*
|
||||||
.nuxt
|
.nuxt
|
||||||
@@ -6,3 +7,6 @@ node_modules
|
|||||||
.output
|
.output
|
||||||
.env
|
.env
|
||||||
dist
|
dist
|
||||||
|
|
||||||
|
# python
|
||||||
|
venv
|
||||||
|
5
requirements.txt
Normal file
5
requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
certifi==2022.12.7
|
||||||
|
charset-normalizer==3.0.1
|
||||||
|
idna==3.4
|
||||||
|
requests==2.28.2
|
||||||
|
urllib3==1.26.14
|
36
server.py
36
server.py
@@ -210,6 +210,40 @@ def put_watermark(img, wm_encoder=None):
|
|||||||
img = Image.fromarray(img[:, :, ::-1])
|
img = Image.fromarray(img[:, :, ::-1])
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
import time
|
||||||
|
import requests
|
||||||
|
|
||||||
|
# 获取model, 如果和之前的model不一样,重新加载
|
||||||
|
def get_model(model_name):
|
||||||
|
global model
|
||||||
|
global config
|
||||||
|
global device
|
||||||
|
if model_name != model_name:
|
||||||
|
config = OmegaConf.load(f"{opt.config}")
|
||||||
|
device = torch.device("cuda") if opt.device == "cuda" else torch.device("cpu")
|
||||||
|
model = load_model_from_config(config, f"{opt.ckpt}", device)
|
||||||
|
return model
|
||||||
|
|
||||||
|
# 使用指定的模型和配置文件进行推理一组参数
|
||||||
|
def drawing(model_name):
|
||||||
|
model = get_model(model_name)
|
||||||
|
if opt.plms:
|
||||||
|
sampler = PLMSSampler(model, device=device)
|
||||||
|
elif opt.dpm:
|
||||||
|
sampler = DPMSolverSampler(model, device=device)
|
||||||
|
else:
|
||||||
|
sampler = DDIMSampler(model, device=device)
|
||||||
|
|
||||||
|
|
||||||
|
def main_dev(opt):
|
||||||
|
while True:
|
||||||
|
time.sleep(1) # 延时1s执行, 避免cpu占用过高
|
||||||
|
# 从局域网中获取一组参数
|
||||||
|
request = requests.get("http://localhost:3000/api/drawing")
|
||||||
|
if request.status_code == 200:
|
||||||
|
data = request.json()
|
||||||
|
print("data: ", data)
|
||||||
|
#drawing("model_name")
|
||||||
|
|
||||||
def main(opt):
|
def main(opt):
|
||||||
seed_everything(opt.seed)
|
seed_everything(opt.seed)
|
||||||
@@ -385,4 +419,4 @@ def main(opt):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
opt = parse_args()
|
opt = parse_args()
|
||||||
main(opt)
|
main_dev(opt)
|
||||||
|
Reference in New Issue
Block a user