import time
import gc
import requests
import argparse
import os
import cv2
import torch
import numpy as np
from typing import Optional
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import nullcontext
from imwatermark import WatermarkEncoder

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.dpm_solver import DPMSolverSampler

torch.set_grad_enabled(False)

# Base URL of the drawing-task API this worker polls and reports to.
TASK_API_URL = "http://localhost:3000/api/drawing"
# Seconds to wait before polling again when no task is available or the API is unreachable.
POLL_INTERVAL_S = 2
# Per-request timeout in seconds, so a stalled API cannot hang the worker forever.
REQUEST_TIMEOUT_S = 30
# Directory holding the `<name>.yaml` / `<name>.ckpt` pairs a task may request.
MODEL_DIR = "/data"


def chunk(it, size):
    """Yield successive tuples of up to ``size`` items from ``it``.

    The last tuple may be shorter than ``size``.
    """
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())


def load_model_from_config(config, ckpt, device=torch.device("cuda"), verbose=False):
    """Instantiate the model described by ``config`` and load weights from ``ckpt``.

    Args:
        config: OmegaConf config; its ``model`` node is passed to ``instantiate_from_config``.
        ckpt: path to a checkpoint file containing a ``state_dict`` entry.
        device: ``torch.device("cuda")`` or ``torch.device("cpu")``.
        verbose: if true, print missing / unexpected state-dict keys.

    Returns:
        The model in eval mode, moved to ``device``.

    Raises:
        ValueError: if ``device`` is neither CUDA nor CPU.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(security): torch.load unpickles the file, which can execute arbitrary code.
    # Only point this at checkpoints from a trusted location.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False tolerates key mismatches; they are only reported when verbose.
    m, u = model.load_state_dict(sd, strict=False)
    if m and verbose:
        print("missing keys:")
        print(m)
    if u and verbose:
        print("unexpected keys:")
        print(u)

    if device == torch.device("cuda"):
        model.cuda()
    elif device == torch.device("cpu"):
        model.cpu()
        model.cond_stage_model.device = "cpu"
    else:
        raise ValueError(f"Incorrect device name. Received: {device}")
    model.eval()
    return model


def parse_args():
    """Parse the command-line options of the txt2img worker."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a professional photograph of an astronaut riding a triceratops",
        help="the prompt to render"
    )
    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/txt2img-samples"
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=50,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--plms",
        action='store_true',
        help="use plms sampling",
    )
    parser.add_argument(
        "--dpm",
        action='store_true',
        help="use DPM (2) sampler",
    )
    parser.add_argument(
        "--fixed_code",
        action='store_true',
        help="if enabled, uses the same starting code across all samples ",
    )
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
        help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=3,
        help="sample this often",
    )
    parser.add_argument(
        "--H",
        type=int,
        default=512,
        help="image height, in pixel space",
    )
    parser.add_argument(
        "--W",
        type=int,
        default=512,
        help="image width, in pixel space",
    )
    parser.add_argument(
        "--C",
        type=int,
        default=4,
        help="latent channels",
    )
    parser.add_argument(
        "--f",
        type=int,
        default=8,
        help="downsampling factor, most often 8 or 16",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=3,
        help="how many samples to produce for each given prompt. A.k.a batch size",
    )
    parser.add_argument(
        "--n_rows",
        type=int,
        default=0,
        help="rows in the grid (default: n_samples)",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=9.0,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    parser.add_argument(
        "--from-file",
        type=str,
        help="if specified, load prompts from this file, separated by newlines",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="configs/stable-diffusion/v2-inference.yaml",
        help="path to config which constructs model",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        help="path to checkpoint of model",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="the seed (for reproducible sampling)",
    )
    parser.add_argument(
        "--precision",
        type=str,
        help="evaluate at this precision",
        choices=["full", "autocast"],
        default="autocast"
    )
    parser.add_argument(
        "--repeat",
        type=int,
        default=1,
        help="repeat each prompt in file this often",
    )
    parser.add_argument(
        "--device",
        type=str,
        help="Device on which Stable Diffusion will be run",
        choices=["cpu", "cuda"],
        default="cpu"
    )
    parser.add_argument(
        "--torchscript",
        action='store_true',
        help="Use TorchScript",
    )
    parser.add_argument(
        "--ipex",
        action='store_true',
        help="Use Intel® Extension for PyTorch*",
    )
    parser.add_argument(
        "--bf16",
        action='store_true',
        help="Use bfloat16",
    )
    opt = parser.parse_args()
    return opt


def put_watermark(img, wm_encoder=None):
    """Embed the encoder's watermark into a PIL RGB image.

    The image is returned unchanged when ``wm_encoder`` is None.
    """
    if wm_encoder is not None:
        img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
        img = wm_encoder.encode(img, 'dwtDct')
        img = Image.fromarray(img[:, :, ::-1])
    return img


def update_task_status(task: dict, status: str, progress: int, data: Optional[list] = None):
    """Update ``task`` in place and PUT it back to the task API.

    Args:
        task: task dict as returned by the API; must contain ``id``.
        status: new status string (this file uses "running" and "done").
        progress: progress value to report.
        data: result payload (here, a list of saved image paths). Defaults to a new empty list.
    """
    task["status"] = status
    task["progress"] = progress
    # A fresh list per call avoids sharing one mutable default between tasks.
    task["data"] = [] if data is None else data
    requests.put(f"{TASK_API_URL}/{task['id']}", json=task, timeout=REQUEST_TIMEOUT_S)


def get_tasks():
    """Block until the task API reports at least one waiting task, then return them.

    Polls every ``POLL_INTERVAL_S`` seconds. Network errors and undecodable
    responses are logged and retried instead of terminating the worker.
    """
    while True:
        try:
            tasks = requests.get(f"{TASK_API_URL}?status=waiting", timeout=REQUEST_TIMEOUT_S).json()
        except (requests.RequestException, ValueError) as err:
            print(f'failed to fetch tasks ({err}), retry in {POLL_INTERVAL_S}s...')
        else:
            if tasks:
                return tasks
            print(f'no task, wait {POLL_INTERVAL_S}s...')
        time.sleep(POLL_INTERVAL_S)


def _model_paths(model_name):
    """Map a task's checkpoint name to its (config, ckpt) file paths under MODEL_DIR.

    Raises:
        ValueError: if the name is empty or could escape MODEL_DIR. It arrives
            from the task API, so it is treated as untrusted.
    """
    if not model_name or model_name in (".", "..") or os.path.basename(model_name) != model_name:
        raise ValueError(f"Invalid model name: {model_name!r}")
    return (os.path.join(MODEL_DIR, f"{model_name}.yaml"),
            os.path.join(MODEL_DIR, f"{model_name}.ckpt"))


def _prompt_batches(opt, task, batch_size):
    """Return the prompt batches for ``task``.

    Without ``--from-file`` this is one batch of ``batch_size`` copies of the
    task's prompt. Otherwise the file's lines are read, each repeated
    ``opt.repeat`` times, and split into tuples of up to ``batch_size`` prompts.

    Raises:
        ValueError: if the task's prompt is None.
    """
    if not opt.from_file:
        prompt = task['prompt']
        if prompt is None:
            raise ValueError(f"task {task.get('id')} has no prompt")
        return [batch_size * [prompt]]

    print(f"reading prompts from {opt.from_file}")
    with open(opt.from_file, "r", encoding="utf-8") as f:
        lines = f.read().splitlines()
    lines = [p for p in lines for _ in range(opt.repeat)]
    return list(chunk(lines, batch_size))


def _render_task(opt, task, model, sampler, device, wm_encoder):
    """Sample every prompt batch of ``task`` and save watermarked PNGs.

    Images go to ``<opt.outdir>/samples``. Numbering continues after the files
    already there, so earlier results are not overwritten. All intermediate
    tensors are local to this function and are released when it returns.

    Returns:
        List of saved image paths, in generation order.
    """
    sample_path = os.path.join(opt.outdir, "samples")
    os.makedirs(sample_path, exist_ok=True)  # also creates opt.outdir
    base_count = len(os.listdir(sample_path))
    sample_count = 0

    batch_size = task['number']
    data = _prompt_batches(opt, task, batch_size)
    print("data:", data)

    shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
    start_code = None
    if opt.fixed_code:
        start_code = torch.randn([batch_size, *shape], device=device)

    precision_scope = autocast if opt.precision == "autocast" or opt.bf16 else nullcontext
    images = []
    with torch.no_grad(), precision_scope(opt.device), model.ema_scope():
        print("Sampling:", data)
        for prompts in tqdm(data, desc="data"):
            prompts = list(prompts)
            # The last --from-file batch can be shorter than batch_size, so size
            # the unconditional batch, the sampler batch and x_T to this batch.
            n = len(prompts)
            uc = None
            if opt.scale != 1.0:
                uc = model.get_learned_conditioning(n * [""])
            c = model.get_learned_conditioning(prompts)
            samples, _ = sampler.sample(S=opt.steps,
                                        conditioning=c,
                                        batch_size=n,
                                        shape=shape,
                                        verbose=False,
                                        unconditional_guidance_scale=opt.scale,
                                        unconditional_conditioning=uc,
                                        eta=opt.ddim_eta,
                                        x_T=None if start_code is None else start_code[:n])

            x_samples = model.decode_first_stage(samples)
            # Map decoder output from [-1, 1] to [0, 1].
            x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

            for x_sample in x_samples:
                print("Sample count:", sample_count)
                image_path = os.path.join(sample_path, f"{base_count:05}.png")
                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                img = Image.fromarray(x_sample.astype(np.uint8))
                img = put_watermark(img, wm_encoder)
                img.save(image_path)
                base_count += 1
                sample_count += 1
                images.append(image_path)

    print("Sample count:", sample_count)
    return images


def main_dev(opt):
    """Poll the task API forever and render each waiting task.

    Per task:
      1. Mark it "running".
      2. Load the requested checkpoint if it differs from the one in memory.
      3. Generate ``task['number']`` images per prompt batch.
      4. Report the saved paths with status "done".

    NOTE(review): any exception while processing a task ends the worker and
    leaves that task in "running". Consider catching per task and reporting
    a failure status, if the API defines one.
    """
    model_name = ''  # name of the checkpoint currently loaded
    model = None
    sampler = None
    config = None
    device = None

    # The watermark encoder does not depend on the task, so build it once.
    wm = "SDV2"
    wm_encoder = WatermarkEncoder()
    wm_encoder.set_watermark('bytes', wm.encode('utf-8'))

    while True:
        for task in get_tasks():
            print('task:', task)
            update_task_status(task, "running", 0)

            # Reload only when the requested checkpoint changes.
            if model is None or task['ckpt'] != model_name:
                opt.config, opt.ckpt = _model_paths(task['ckpt'])
                model_name = task['ckpt']
                # NOTE(review): the worker always runs on CUDA, overriding --device.
                opt.device = 'cuda'
                print(f"config: {opt.config}", f"ckpt: {opt.ckpt}", f"device: {opt.device}")

                # Drop every reference to the previous model before loading the
                # next one, so both are never resident in GPU memory at once.
                sampler = None
                model = None
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                config = OmegaConf.load(f"{opt.config}")
                device = torch.device("cuda") if opt.device == "cuda" else torch.device("cpu")
                print(f"加载模型到显存: {model_name}..")
                model = load_model_from_config(config, f"{opt.ckpt}", device)
                print(f"加载到显存完成: {model_name}")

            if opt.plms:
                sampler = PLMSSampler(model, device=device)
            elif opt.dpm:
                sampler = DPMSolverSampler(model, device=device)
            else:
                sampler = DDIMSampler(model, device=device)

            images = _render_task(opt, task, model, sampler, device, wm_encoder)
            update_task_status(task=task, status='done', progress=1, data=images)
            print("批次任务结束..")


if __name__ == "__main__":
    opt = parse_args()
    main_dev(opt)