python
This commit is contained in:
		
							
								
								
									
										4
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@@ -1,3 +1,4 @@
 | 
				
			|||||||
 | 
					# javascript
 | 
				
			||||||
node_modules
 | 
					node_modules
 | 
				
			||||||
*.log*
 | 
					*.log*
 | 
				
			||||||
.nuxt
 | 
					.nuxt
 | 
				
			||||||
@@ -6,3 +7,6 @@ node_modules
 | 
				
			|||||||
.output
 | 
					.output
 | 
				
			||||||
.env
 | 
					.env
 | 
				
			||||||
dist
 | 
					dist
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# python
 | 
				
			||||||
 | 
					venv
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										5
									
								
								requirements.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								requirements.txt
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,5 @@
 | 
				
			|||||||
 | 
					certifi==2022.12.7
 | 
				
			||||||
 | 
					charset-normalizer==3.0.1
 | 
				
			||||||
 | 
					idna==3.4
 | 
				
			||||||
 | 
					requests==2.28.2
 | 
				
			||||||
 | 
					urllib3==1.26.14
 | 
				
			||||||
							
								
								
									
										36
									
								
								server.py
									
									
									
									
									
								
							
							
						
						
									
										36
									
								
								server.py
									
									
									
									
									
								
							@@ -210,6 +210,40 @@ def put_watermark(img, wm_encoder=None):
 | 
				
			|||||||
        img = Image.fromarray(img[:, :, ::-1])
 | 
					        img = Image.fromarray(img[:, :, ::-1])
 | 
				
			||||||
    return img
 | 
					    return img
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import time
 | 
				
			||||||
 | 
					import requests
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 获取model, 如果和之前的model不一样,重新加载
 | 
				
			||||||
 | 
					def get_model(model_name):
 | 
				
			||||||
 | 
					    global model
 | 
				
			||||||
 | 
					    global config
 | 
				
			||||||
 | 
					    global device
 | 
				
			||||||
 | 
					    if model_name != model_name:
 | 
				
			||||||
 | 
					        config = OmegaConf.load(f"{opt.config}")
 | 
				
			||||||
 | 
					        device = torch.device("cuda") if opt.device == "cuda" else torch.device("cpu")
 | 
				
			||||||
 | 
					        model = load_model_from_config(config, f"{opt.ckpt}", device)
 | 
				
			||||||
 | 
					    return model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 使用指定的模型和配置文件进行推理一组参数
 | 
				
			||||||
 | 
					def drawing(model_name):
 | 
				
			||||||
 | 
					    model = get_model(model_name)
 | 
				
			||||||
 | 
					    if opt.plms:
 | 
				
			||||||
 | 
					        sampler = PLMSSampler(model, device=device)
 | 
				
			||||||
 | 
					    elif opt.dpm:
 | 
				
			||||||
 | 
					        sampler = DPMSolverSampler(model, device=device)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        sampler = DDIMSampler(model, device=device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def main_dev(opt):
 | 
				
			||||||
 | 
					    while True:
 | 
				
			||||||
 | 
					        time.sleep(1) # 延时1s执行, 避免cpu占用过高
 | 
				
			||||||
 | 
					        # 从局域网中获取一组参数
 | 
				
			||||||
 | 
					        request = requests.get("http://localhost:3000/api/drawing")
 | 
				
			||||||
 | 
					        if request.status_code == 200:
 | 
				
			||||||
 | 
					            data = request.json()
 | 
				
			||||||
 | 
					            print("data: ", data)
 | 
				
			||||||
 | 
					        #drawing("model_name")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def main(opt):
 | 
					def main(opt):
 | 
				
			||||||
    seed_everything(opt.seed)
 | 
					    seed_everything(opt.seed)
 | 
				
			||||||
@@ -385,4 +419,4 @@ def main(opt):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    opt = parse_args()
 | 
					    opt = parse_args()
 | 
				
			||||||
    main(opt)
 | 
					    main_dev(opt)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user