DEBUG
server.py (85 changed lines)
@@ -1,4 +1,7 @@
-import argparse, os
+import time
+import requests
+import argparse
+import os
 import cv2
 import torch
 import numpy as np
@@ -20,6 +23,7 @@ from ldm.models.diffusion.dpm_solver import DPMSolverSampler
 
 torch.set_grad_enabled(False)
 
+
 def chunk(it, size):
     it = iter(it)
     return iter(lambda: tuple(islice(it, size)), ())
@@ -210,41 +214,44 @@ def put_watermark(img, wm_encoder=None):
         img = Image.fromarray(img[:, :, ::-1])
     return img
 
-import time
-import requests
 
 # update a task's status
-def update_task_status(task:dict, status:str, progress:int):
+def update_task_status(task: dict, status: str, progress: int):
     task["status"] = status
     task["progress"] = progress
     requests.put(f"http://localhost:3000/api/drawing/{task['id']}", json=task)
 
+
 def main_dev(opt):
-    model_name = ''   # default model
-    model      = None # default model
-    config     = None # default config
-    device     = None # default device
+    model_name = '' # default model
+    model = None    # default model
+    config = None   # default config
+    device = None   # default device
     while True:
-        time.sleep(2)                                                   # sleep 2s between polls to avoid high CPU usage
-        data = requests.get("http://localhost:3000/api/drawing").json() # fetch a batch of task parameters over the LAN
-        print(data) # e.g. [{'model': '768-v-ema', 'prompt': 'a cat', 'watermark': '0'}, {'model': '768-v-ema', 'prompt': 'a dog', 'watermark': '0'}]
+        time.sleep(2)  # sleep 2s between polls to avoid high CPU usage
+        data = requests.get("http://localhost:3000/api/drawing").json()  # fetch a batch of task parameters over the LAN
+        print(data)
        # iterate over data; each item is a task dict
         for item in data:
             print(item)
-            update_task_status(item, "running", 0) # mark the task as running
+            update_task_status(item, "running", 0)  # mark the task as running
 
             # set the parameters
-            if 'prompt'    in item: opt.prompt    = item['prompt']    # prompt
-            if 'n_samples' in item: opt.n_samples = item['n_samples'] # columns
-            if 'n_rows'    in item: opt.n_rows    = item['n_rows']    # rows
-            if 'scale'     in item: opt.scale     = item['scale']     # scale
+            if 'prompt' in item:
+                opt.prompt = item['prompt']    # prompt
+            if 'n_samples' in item:
+                opt.n_samples = item['n_samples']  # columns
+            if 'n_rows' in item:
+                opt.n_rows = item['n_rows']    # rows
+            if 'scale' in item:
+                opt.scale = item['scale']     # scale
 
             # if the model changed, reload it (take care to free memory)
             if item['ckpt'] != model_name:
                 # load the runtime configuration
                 model_name = item['ckpt']
                 opt.config = f'/data/{model_name}.yaml'
-                opt.ckpt   = f'/data/{model_name}.ckpt'
+                opt.ckpt = f'/data/{model_name}.ckpt'
                 opt.device = 'cuda'
                 print(f"config: {opt.config}", f"ckpt: {opt.ckpt}", f"device: {opt.device}")
                 config = OmegaConf.load(f"{opt.config}")
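Note on the reload branch above: its comment says to take care to free memory, but the hunk only rebinds model_name and the opt paths, so the previous checkpoint's weights stay resident until the new model assignment drops them. A minimal sketch of releasing the old model explicitly before the reload (an illustration only: it assumes `model` holds the sole live reference and CUDA is the active device):

    import gc
    import torch

    # Before loading a different checkpoint, drop the old weights first.
    model = None                  # release the only reference to the old model
    gc.collect()                  # collect the now-unreachable module graph
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return cached VRAM blocks to the CUDA driver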
@@ -300,8 +307,7 @@ def main_dev(opt):
                 if opt.bf16 and not opt.torchscript and not opt.ipex:
                     raise ValueError('Bfloat16 is supported only for torchscript+ipex')
                 if opt.bf16 and unet.dtype != torch.bfloat16:
-                    raise ValueError("Use configs/stable-diffusion/intel/ configs with bf16 enabled if " +
-                                     "you'd like to use bfloat16 with CPU.")
+                    raise ValueError("Use configs/stable-diffusion/intel/ configs with bf16 enabled if you'd like to use bfloat16 with CPU.")
                 if unet.dtype == torch.float16 and device == torch.device("cpu"):
                     raise ValueError("Use configs/stable-diffusion/intel/ configs for your model if you'd like to run it on CPU.")
                 if opt.ipex:
@@ -317,8 +323,7 @@ def main_dev(opt):
                     with torch.no_grad(), additional_context:
                         # get UNET scripted
                         if unet.use_checkpoint:
-                            raise ValueError("Gradient checkpoint won't work with tracing. " +
-                            "Use configs/stable-diffusion/intel/ configs for your model or disable checkpoint in your config.")
+                            raise ValueError("Gradient checkpoint won't work with tracing. Use configs/stable-diffusion/intel/ configs for your model or disable checkpoint in your config.")
                         img_in = torch.ones(2, 4, 96, 96, dtype=torch.float32)
                         t_in = torch.ones(2, dtype=torch.int64)
                         context = torch.ones(2, 77, 1024, dtype=torch.float32)
@@ -354,9 +359,9 @@ def main_dev(opt):
                     print("Running a forward pass for decoder")
                     for _ in range(3):
                         x_samples_ddim = model.decode_first_stage(samples_ddim)
-            precision_scope = autocast if opt.precision=="autocast" or opt.bf16 else nullcontext
+            precision_scope = autocast if opt.precision == "autocast" or opt.bf16 else nullcontext
             with torch.no_grad(), precision_scope(opt.device), model.ema_scope():
-                all_samples = list()
+                #all_samples = list()
                 # run the requested number of iterations
                 for n in trange(item['number'], desc="Sampling"):
                     print("Sampling:", n)
@@ -369,26 +374,27 @@ def main_dev(opt):
                         c = model.get_learned_conditioning(prompts)
                         shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
                         samples, _ = sampler.sample(S=opt.steps,
-                                                            conditioning=c,
-                                                            batch_size=opt.n_samples,
-                                                            shape=shape,
-                                                            verbose=False,
-                                                            unconditional_guidance_scale=opt.scale,
-                                                            unconditional_conditioning=uc,
-                                                            eta=opt.ddim_eta,
-                                                            x_T=start_code)
+                                                    conditioning=c,
+                                                    batch_size=opt.n_samples,
+                                                    shape=shape,
+                                                    verbose=False,
+                                                    unconditional_guidance_scale=opt.scale,
+                                                    unconditional_conditioning=uc,
+                                                    eta=opt.ddim_eta,
+                                                    x_T=start_code)
                         x_samples = model.decode_first_stage(samples)
                         x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
                         for x_sample in x_samples:
+                            print("Sample count:", sample_count)
                             x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                             img = Image.fromarray(x_sample.astype(np.uint8))
                             img = put_watermark(img, wm_encoder)
                             img.save(os.path.join(sample_path, f"{base_count:05}.png"))
                             base_count += 1
                             sample_count += 1
-                        all_samples.append(x_samples)
+                        #all_samples.append(x_samples)
                         print("Sample count:", sample_count)
-                #for n in trange(opt.n_iter, desc="Sampling"):
+                # for n in trange(opt.n_iter, desc="Sampling"):
                 #    for prompts in tqdm(data, desc="data"):
                 #        uc = None
                 #        if opt.scale != 1.0:
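For reference, the save loop above converts each decoded sample from a CHW float tensor in [0, 1] to an HWC uint8 array before handing it to PIL. A tiny self-contained example of that conversion, using einops' rearrange as the script already does (the toy tensor shape is illustrative):

    import numpy as np
    import torch
    from einops import rearrange

    x = torch.rand(3, 2, 2)                        # one sample in [0, 1], CHW layout
    img = 255. * rearrange(x.numpy(), 'c h w -> h w c')
    print(img.astype(np.uint8).shape)              # (2, 2, 3), ready for Image.fromarray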
@@ -416,24 +422,23 @@ def main_dev(opt):
                 #            base_count += 1
                 #            sample_count += 1
                 #        all_samples.append(x_samples)
-                ## additionally, save as grid
+                # additionally, save as grid
                 #grid = torch.stack(all_samples, 0)
                 #grid = rearrange(grid, 'n b c h w -> (n b) c h w')
                 #grid = make_grid(grid, nrow=n_rows)
-                ## to image
+                # to image
                 #grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                 #grid = Image.fromarray(grid.astype(np.uint8))
                 #grid = put_watermark(grid, wm_encoder)
                 #grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
                 #grid_count += 1
             print(f"Your samples are ready and waiting for you here: \n{outpath} \n", f" \nEnjoy.")
-            # mark the task as done
-            update_task_status(task=item, status='done', progress=1)
-        # task finished; wait 20s before exiting
-        print("Task finished, exiting in 20s..")
-        time.sleep(20)
+            update_task_status(task=item, status='done', progress=1)  # mark the task as done
+        print("Task finished, exiting in 10s..")
+        time.sleep(10)
         break
 
+
 if __name__ == "__main__":
     opt = parse_args()
     main_dev(opt)
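The worker loop introduced in this commit assumes the task API at http://localhost:3000 is always reachable; a single failed GET or PUT raises and kills the process. A hedged sketch of the same polling pattern with basic error handling (same endpoints and 2s cadence as the diff; the timeout values and helper name are illustrative assumptions, not part of the commit):

    import time
    import requests

    API = "http://localhost:3000/api/drawing"  # endpoint used by server.py

    def poll_once():
        """Fetch pending tasks and mark each as running; tolerate network errors."""
        try:
            tasks = requests.get(API, timeout=5).json()
        except (requests.RequestException, ValueError):
            return []  # server unreachable or malformed JSON: retry next tick
        for task in tasks:
            task["status"], task["progress"] = "running", 0
            try:
                requests.put(f"{API}/{task['id']}", json=task, timeout=5)
            except requests.RequestException:
                pass  # the status update is best-effort
        return tasks

    if __name__ == "__main__":
        while True:        # sketch: poll forever instead of break-after-one-batch
            poll_once()
            time.sleep(2)  # same 2s back-off as main_dev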