From 87a9c5bd3546b9c2fd703c07233f731376ab32e3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=A1=9C=E8=8F=AF?=
Date: Sat, 18 Feb 2023 23:50:16 +0800
Subject: [PATCH] Python does not optimize recursion; switch to a while loop
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 server.py | 176 ++++++++++++------------------------------------------
 venv      |   1 -
 2 files changed, 37 insertions(+), 140 deletions(-)
 delete mode 120000 venv

diff --git a/server.py b/server.py
index a6149d4..b1caafe 100644
--- a/server.py
+++ b/server.py
@@ -216,11 +216,21 @@ def put_watermark(img, wm_encoder=None):
 
 
 # update a task's status
-def update_task_status(task: dict, status: str, progress: int):
+def update_task_status(task: dict, status: str, progress: int, data: list = []):
     task["status"] = status
     task["progress"] = progress
+    task["data"] = data
     requests.put(f"http://localhost:3000/api/drawing/{task['id']}", json=task)
 
+# fetch a batch of tasks from the LAN (if the list is empty, wait 2s and fetch again)
+def get_tasks():
+    tasks = requests.get("http://localhost:3000/api/drawing?status=waiting").json()
+    # poll until at least one waiting task shows up
+    while len(tasks) == 0:
+        print('no task, wait 2s...')
+        time.sleep(2)
+        tasks = requests.get("http://localhost:3000/api/drawing?status=waiting").json()
+    return tasks
 
 def main_dev(opt):
     model_name = ''  # default model
@@ -228,29 +238,14 @@ def main_dev(opt):
     config = None  # default config
     device = None  # default device
     while True:
-        time.sleep(2)  # delay between polls to avoid high CPU usage
-        data = requests.get("http://localhost:3000/api/drawing").json()  # fetch a batch of parameters from the LAN
-        print(data)
-        # iterate over data; each entry is a dict
-        for item in data:
-            print(item)
-            update_task_status(item, "running", 0)  # mark the task as running
-
-            # set the parameters
-            if 'prompt' in item:
-                opt.prompt = item['prompt']  # prompt text
-            if 'number' in item:
-                opt.n_samples = item['number']  # number of columns
-                print(f"n_samples: {opt.n_samples}")
-            #if 'n_rows' in item:
-            #    opt.n_rows = item['n_rows']  # number of rows
-            if 'scale' in item:
-                opt.scale = item['scale']  # guidance scale
+        for task in get_tasks():  # iterate over tasks; each task is a dict
+            print('task:', task)  # log the task
+            update_task_status(task, "running", 0)  # mark the task as running
 
             # if the model differs, reload it (remember to free the memory)
-            if item['ckpt'] != model_name:
+            if task['ckpt'] != model_name:
                 # load the environment config
-                model_name = item['ckpt']
+                model_name = task['ckpt']
                 opt.config = f'/data/{model_name}.yaml'
                 opt.ckpt = f'/data/{model_name}.ckpt'
                 opt.device = 'cuda'
@@ -261,6 +256,7 @@ def main_dev(opt):
                 print(f"Loading model into VRAM: {model_name}..")
                 model = load_model_from_config(config, f"{opt.ckpt}", device)
                 print(f"Model loaded into VRAM: {model_name}")
 
+            # run inference on one set of parameters with the specified model and config
             if opt.plms:
                 sampler = PLMSSampler(model, device=device)
@@ -268,104 +264,46 @@ def main_dev(opt):
                 sampler = DPMSolverSampler(model, device=device)
             else:
                 sampler = DDIMSampler(model, device=device)
+            # make sure the output directory exists
             os.makedirs(opt.outdir, exist_ok=True)
             outpath = opt.outdir
 
+            # create the watermark encoder
             wm = "SDV2"
             wm_encoder = WatermarkEncoder()
             wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
 
+            # take the batch size from the task
-            batch_size = opt.n_samples
-            #n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
+            batch_size = task['number']
+
             if not opt.from_file:
-                prompt = opt.prompt
+                prompt = task['prompt']
                 assert prompt is not None
                 data = [batch_size * [prompt]]
+                print("data:", data)
             else:
                 print(f"reading prompts from {opt.from_file}")
                 with open(opt.from_file, "r") as f:
                     data = f.read().splitlines()
                     data = [p for p in data for i in range(opt.repeat)]
                     data = list(chunk(data, batch_size))
+                    print("data:", data)
 
+            # prepare the sample output directory
             sample_path = os.path.join(outpath, "samples")
             os.makedirs(sample_path, exist_ok=True)
             sample_count = 0
             base_count = len(os.listdir(sample_path))
-            grid_count = len(os.listdir(outpath)) - 1
 
+            # optionally use a fixed start code
             start_code = None
             if opt.fixed_code:
-                start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
-
-            '''
-            # switch models
-            if opt.torchscript or opt.ipex:
-                transformer = model.cond_stage_model.model
-                unet = model.model.diffusion_model
-                decoder = model.first_stage_model.decoder
-                additional_context = torch.cpu.amp.autocast() if opt.bf16 else nullcontext()
-                shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
-                if opt.bf16 and not opt.torchscript and not opt.ipex:
-                    raise ValueError('Bfloat16 is supported only for torchscript+ipex')
-                if opt.bf16 and unet.dtype != torch.bfloat16:
-                    raise ValueError("Use configs/stable-diffusion/intel/ configs with bf16 enabled if you'd like to use bfloat16 with CPU.")
-                if unet.dtype == torch.float16 and device == torch.device("cpu"):
-                    raise ValueError("Use configs/stable-diffusion/intel/ configs for your model if you'd like to run it on CPU.")
-                if opt.ipex:
-                    import intel_extension_for_pytorch as ipex
-                    bf16_dtype = torch.bfloat16 if opt.bf16 else None
-                    transformer = transformer.to(memory_format=torch.channels_last)
-                    transformer = ipex.optimize(transformer, level="O1", inplace=True)
-                    unet = unet.to(memory_format=torch.channels_last)
-                    unet = ipex.optimize(unet, level="O1", auto_kernel_selection=True, inplace=True, dtype=bf16_dtype)
-                    decoder = decoder.to(memory_format=torch.channels_last)
-                    decoder = ipex.optimize(decoder, level="O1", auto_kernel_selection=True, inplace=True, dtype=bf16_dtype)
-                if opt.torchscript:
-                    with torch.no_grad(), additional_context:
-                        # get UNET scripted
-                        if unet.use_checkpoint:
-                            raise ValueError("Gradient checkpoint won't work with tracing. Use configs/stable-diffusion/intel/ configs for your model or disable checkpoint in your config.")
-                        img_in = torch.ones(2, 4, 96, 96, dtype=torch.float32)
-                        t_in = torch.ones(2, dtype=torch.int64)
-                        context = torch.ones(2, 77, 1024, dtype=torch.float32)
-                        scripted_unet = torch.jit.trace(unet, (img_in, t_in, context))
-                        scripted_unet = torch.jit.optimize_for_inference(scripted_unet)
-                        print(type(scripted_unet))
-                        model.model.scripted_diffusion_model = scripted_unet
-                        # get Decoder for first stage model scripted
-                        samples_ddim = torch.ones(1, 4, 96, 96, dtype=torch.float32)
-                        scripted_decoder = torch.jit.trace(decoder, (samples_ddim))
-                        scripted_decoder = torch.jit.optimize_for_inference(scripted_decoder)
-                        print(type(scripted_decoder))
-                        model.first_stage_model.decoder = scripted_decoder
-                prompts = data[0]
-                print("Running a forward pass to initialize optimizations")
-                uc = None
-                if opt.scale != 1.0:
-                    uc = model.get_learned_conditioning(batch_size * [""])
-                if isinstance(prompts, tuple):
-                    prompts = list(prompts)
-                with torch.no_grad(), additional_context:
-                    for _ in range(3):
-                        c = model.get_learned_conditioning(prompts)
-                    samples_ddim, _ = sampler.sample(S=5,
-                                                     conditioning=c,
-                                                     batch_size=batch_size,
-                                                     shape=shape,
-                                                     verbose=False,
-                                                     unconditional_guidance_scale=opt.scale,
-                                                     unconditional_conditioning=uc,
-                                                     eta=opt.ddim_eta,
-                                                     x_T=start_code)
-                    print("Running a forward pass for decoder")
-                    for _ in range(3):
-                        x_samples_ddim = model.decode_first_stage(samples_ddim)
-
+                start_code = torch.randn([task['number'], opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
+            # generate the images
             precision_scope = autocast if opt.precision == "autocast" or opt.bf16 else nullcontext
             with torch.no_grad(), precision_scope(opt.device), model.ema_scope():
-                #all_samples = list()
-                # run the requested task batch (rows)(item['number'])
+                images = []
+                # run the requested task batch (rows)(task['number'])
                 for n in trange(1, desc="Sampling"):
                     print("Sampling:", data)
                     for prompts in tqdm(data, desc="data"):
@@ -378,7 +316,7 @@ def main_dev(opt):
                         shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
                         samples, _ = sampler.sample(S=opt.steps,
                                                     conditioning=c,
-                                                    batch_size=opt.n_samples,
+                                                    batch_size=task['number'],
                                                     shape=shape,
                                                     verbose=False,
                                                     unconditional_guidance_scale=opt.scale,
@@ -389,58 +327,18 @@ def main_dev(opt):
                         x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
                         for x_sample in x_samples:
                             print("Sample count:", sample_count)
+                            image_path = os.path.join(sample_path, f"{base_count:05}.png")
                             x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                             img = Image.fromarray(x_sample.astype(np.uint8))
                             img = put_watermark(img, wm_encoder)
-                            img.save(os.path.join(sample_path, f"{base_count:05}.png"))
+                            img.save(image_path)
                             base_count += 1
                             sample_count += 1
-                            #all_samples.append(x_samples)
+                            images.append(image_path)
 
                         print("Sample count:", sample_count)
-            # for n in trange(opt.n_iter, desc="Sampling"):
-            #     for prompts in tqdm(data, desc="data"):
-            #         uc = None
-            #         if opt.scale != 1.0:
-            #             uc = model.get_learned_conditioning(batch_size * [""])
-            #         if isinstance(prompts, tuple):
-            #             prompts = list(prompts)
-            #         c = model.get_learned_conditioning(prompts)
-            #         shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
-            #         samples, _ = sampler.sample(S=opt.steps,
-            #                                     conditioning=c,
-            #                                     batch_size=opt.n_samples,
-            #                                     shape=shape,
-            #                                     verbose=False,
-            #                                     unconditional_guidance_scale=opt.scale,
-            #                                     unconditional_conditioning=uc,
-            #                                     eta=opt.ddim_eta,
-            #                                     x_T=start_code)
-            #         x_samples = model.decode_first_stage(samples)
-            #         x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
-            #         for x_sample in x_samples:
-            #             x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-            #             img = Image.fromarray(x_sample.astype(np.uint8))
-            #             img = put_watermark(img, wm_encoder)
-            #             img.save(os.path.join(sample_path, f"{base_count:05}.png"))
-            #             base_count += 1
-            #             sample_count += 1
-            #             all_samples.append(x_samples)
-            # additionally, save as grid
-            #grid = torch.stack(all_samples, 0)
-            #grid = rearrange(grid, 'n b c h w -> (n b) c h w')
-            #grid = make_grid(grid, nrow=n_rows)
-            # to image
-            #grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
-            #grid = Image.fromarray(grid.astype(np.uint8))
-            #grid = put_watermark(grid, wm_encoder)
-            #grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
-            #grid_count += 1
-            print(f"Your samples are ready and waiting for you here: \n{outpath} \n", f" \nEnjoy.")
-            update_task_status(task=item, status='done', progress=1)  # mark the task as done
-            '''
-        print("Task finished, waiting 10s before exit..")
-        #time.sleep(10)
-        break
+            update_task_status(task=task, status='done', progress=1, data=images)  # mark the task as done
+        print("Batch of tasks finished..")
+        #break
 
 
 if __name__ == "__main__":
diff --git a/venv b/venv
deleted file mode 120000
index de9d3d0..0000000
--- a/venv
+++ /dev/null
@@ -1 +0,0 @@
-/data/stablediffusion/venv
\ No newline at end of file