Files
drawing/server.py
2023-03-01 02:04:56 +08:00

368 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import time
import requests
import argparse
import os
import cv2
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import nullcontext
from imwatermark import WatermarkEncoder
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.dpm_solver import DPMSolverSampler
# Inference only: switch off autograd tracking as soon as this module is imported.
torch.set_grad_enabled(mode=False)
def chunk(it, size):
    """Split an iterable into consecutive tuples of at most ``size`` items.

    The last tuple may be shorter. Items are pulled from ``it`` lazily, and
    iteration ends as soon as a batch comes back empty.
    """
    source = iter(it)

    def _batches():
        while True:
            batch = tuple(islice(source, size))
            if not batch:
                return
            yield batch

    return _batches()
def load_model_from_config(config, ckpt, device=torch.device("cuda"), verbose=False):
    """Instantiate the model described by ``config`` and load weights from ``ckpt``.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint readable by ``torch.load``. Either a
            Lightning-style dict with a ``"state_dict"`` entry or a bare
            state dict.
        device: ``torch.device("cuda")`` / ``torch.device("cpu")``, or the
            equivalent strings ``"cuda"`` / ``"cpu"``.
        verbose: Print the missing/unexpected keys reported by
            ``load_state_dict(strict=False)``.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        ValueError: If ``device`` is not plain CUDA or CPU. This is checked
            before the (slow) checkpoint load.
    """
    if isinstance(device, str):
        try:
            device = torch.device(device)
        except (RuntimeError, ValueError) as err:
            raise ValueError(f"Incorrect device name. Received: {device}") from err
    cuda, cpu = torch.device("cuda"), torch.device("cpu")
    # Fail fast: don't spend time reading a multi-GB checkpoint for a device we reject.
    if device not in (cuda, cpu):
        raise ValueError(f"Incorrect device name. Received: {device}")

    print(f"Loading model from {ckpt}")
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only load checkpoints from trusted locations.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    # Lightning checkpoints nest the weights under "state_dict"; otherwise the
    # loaded object is treated as the state dict itself.
    sd = pl_sd.get("state_dict", pl_sd)
    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose and missing:
        print("missing keys:")
        print(missing)
    if verbose and unexpected:
        print("unexpected keys:")
        print(unexpected)

    if device == cuda:
        model.cuda()
    else:
        model.cpu()
        # The conditioning stage keeps its own ``device`` attribute; point it at the CPU too.
        model.cond_stage_model.device = "cpu"
    model.eval()
    return model
def parse_args(argv=None):
    """Build the command-line parser and parse the options.

    Note that ``main_dev`` overrides ``config``, ``ckpt`` and ``device`` per
    task, and takes prompt and batch size from the task rather than from
    ``--prompt`` / ``--n_samples``.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, which is what the script entry point uses.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a professional photograph of an astronaut riding a triceratops",
        help="the prompt to render",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/txt2img-samples",
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=50,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--plms",
        action='store_true',
        help="use plms sampling",
    )
    parser.add_argument(
        "--dpm",
        action='store_true',
        help="use DPM (2) sampler",
    )
    parser.add_argument(
        "--fixed_code",
        action='store_true',
        help="if enabled, uses the same starting code across all samples ",
    )
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
        help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=3,
        help="sample this often",
    )
    parser.add_argument(
        "--H",
        type=int,
        default=512,
        help="image height, in pixel space",
    )
    parser.add_argument(
        "--W",
        type=int,
        default=512,
        help="image width, in pixel space",
    )
    parser.add_argument(
        "--C",
        type=int,
        default=4,
        help="latent channels",
    )
    parser.add_argument(
        "--f",
        type=int,
        default=8,
        help="downsampling factor, most often 8 or 16",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=3,
        help="how many samples to produce for each given prompt. A.k.a batch size",
    )
    parser.add_argument(
        "--n_rows",
        type=int,
        default=0,
        help="rows in the grid (default: n_samples)",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=9.0,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    parser.add_argument(
        "--from-file",
        type=str,
        help="if specified, load prompts from this file, separated by newlines",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="configs/stable-diffusion/v2-inference.yaml",
        help="path to config which constructs model",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        help="path to checkpoint of model",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="the seed (for reproducible sampling)",
    )
    parser.add_argument(
        "--precision",
        type=str,
        help="evaluate at this precision",
        choices=["full", "autocast"],
        default="autocast",
    )
    parser.add_argument(
        "--repeat",
        type=int,
        default=1,
        help="repeat each prompt in file this often",
    )
    parser.add_argument(
        "--device",
        type=str,
        help="Device on which Stable Diffusion will be run",
        choices=["cpu", "cuda"],
        default="cpu",
    )
    parser.add_argument(
        "--torchscript",
        action='store_true',
        help="Use TorchScript",
    )
    parser.add_argument(
        "--ipex",
        action='store_true',
        help="Use Intel® Extension for PyTorch*",
    )
    parser.add_argument(
        "--bf16",
        action='store_true',
        help="Use bfloat16",
    )
    return parser.parse_args(argv)
import sys
import signal
def quit(signum, frame):
    """Signal handler: announce the shutdown and exit the process.

    Registered for SIGINT and SIGTERM in this module; both handler
    arguments are ignored.
    """
    del signum, frame  # required by the signal-handler signature, unused here
    print('stop fusion')
    raise SystemExit
def put_watermark(img, wm_encoder=None):
    """Embed an invisible watermark into ``img`` using ``wm_encoder``.

    With no encoder the image is returned untouched. Otherwise the image is
    converted from RGB to BGR for the encoder (method ``'dwtDct'``), and the
    channels are flipped back before a new PIL image is built.
    """
    if wm_encoder is None:
        return img
    bgr_pixels = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    marked = wm_encoder.encode(bgr_pixels, 'dwtDct')
    return Image.fromarray(marked[:, :, ::-1])
def update_task_status(task: dict, status: str, progress: float, data: "list | None" = None,
                       *, base_url: str = "http://localhost:3000", timeout: float = 10.0) -> bool:
    """Record a new status on ``task`` and push it to the drawing API.

    ``task`` is updated in place (``status``, ``progress``, ``data``). The
    whole dict is then sent as JSON with
    ``PUT {base_url}/api/drawing/{task['id']}``.

    Args:
        task: Task dict received from the API; must contain ``"id"``.
        status: New status string (this file uses e.g. ``"running"``, ``"done"``).
        progress: Progress value; callers in this file pass fractions in [0, 1].
        data: Result payload (this file passes a list of image paths on
            completion). Defaults to a fresh empty list.
        base_url: Root URL of the drawing API.
        timeout: Seconds to wait for the API before giving up.

    Returns:
        True if the API accepted the update. False if the request failed; the
        failure is logged and not raised, so the worker loop keeps running.
    """
    task["status"] = status
    task["progress"] = progress
    # A new list per call: a shared mutable default would end up aliased by
    # every task dict that relied on it.
    task["data"] = [] if data is None else data
    try:
        response = requests.put(f"{base_url}/api/drawing/{task['id']}", json=task, timeout=timeout)
        response.raise_for_status()
    except requests.RequestException as err:
        # Best effort: a transient API failure must not kill the worker (and drop the loaded model).
        print("update task status error", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), err)
        return False
    return True
def get_tasks(tasks: "list | None" = None, *, base_url: str = "http://localhost:3000",
              poll_interval: float = 2, timeout: float = 10.0) -> list:
    """Block until the drawing API reports waiting tasks, then return them.

    Polls ``GET {base_url}/api/drawing?status=waiting`` and sleeps
    ``poll_interval`` seconds between attempts, while the result is empty or
    the request fails. A non-empty ``tasks`` list is returned immediately
    without polling. Also (re)installs ``quit`` as the SIGINT/SIGTERM handler.

    Args:
        tasks: Optional pre-fetched task list.
        base_url: Root URL of the drawing API.
        poll_interval: Seconds to sleep between polls.
        timeout: Seconds to wait for each HTTP request.

    Returns:
        The list of task dicts returned by the API.
    """
    signal.signal(signal.SIGINT, quit)
    signal.signal(signal.SIGTERM, quit)
    tasks = [] if tasks is None else tasks
    while not tasks:
        try:
            response = requests.get(f"{base_url}/api/drawing?status=waiting", timeout=timeout)
            response.raise_for_status()
            tasks = response.json()
            if not isinstance(tasks, list):
                raise ValueError(f"expected a JSON list of tasks, got {type(tasks).__name__}")
        except Exception as err:
            # Deliberately `except Exception`, not a bare except: the signal
            # handler stops the process by raising SystemExit, which must
            # propagate instead of being logged as a polling error.
            tasks = []
            print("get tasks error", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), err)
            time.sleep(poll_interval)
            continue
        if not tasks:
            time.sleep(poll_interval)
    return tasks
def _make_sampler(opt, model, device):
    """Return the sampler selected by ``--plms`` / ``--dpm`` (DDIM otherwise)."""
    if opt.plms:
        return PLMSSampler(model, device=device)
    if opt.dpm:
        return DPMSolverSampler(model, device=device)
    return DDIMSampler(model, device=device)


def _prompt_batches(opt, prompt, batch_size):
    """Return the prompt batches to sample for one task.

    Without ``--from-file``: a single batch with ``prompt`` repeated
    ``batch_size`` times. With it: every line of the file, repeated
    ``--repeat`` times, grouped into tuples of at most ``batch_size``.

    Raises:
        ValueError: If ``--from-file`` is not given and ``prompt`` is None.
    """
    if not opt.from_file:
        if prompt is None:
            raise ValueError("task has no prompt")
        return [batch_size * [prompt]]
    print(f"reading prompts from {opt.from_file}")
    with open(opt.from_file, "r", encoding="utf-8") as f:
        lines = f.read().splitlines()
    repeated = [p for p in lines for _ in range(opt.repeat)]
    return list(chunk(repeated, batch_size))


def _save_samples(x_samples, sample_path, base_count, wm_encoder):
    """Watermark decoded samples and save them as ``{index:05}.png`` files.

    Each item of ``x_samples`` is rearranged ``c h w -> h w c`` and scaled by
    255, so items are expected to be CHW tensors in [0, 1] (the caller clamps
    them). File numbering starts at ``base_count``.

    Returns:
        (list of absolute paths written, next unused index)
    """
    paths = []
    for x_sample in x_samples:
        image_path = os.path.abspath(os.path.join(sample_path, f"{base_count:05}.png"))
        pixels = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
        img = put_watermark(Image.fromarray(pixels.astype(np.uint8)), wm_encoder)
        img.save(image_path)
        paths.append(image_path)
        base_count += 1
    return paths, base_count


def main_dev(opt):
    """Worker loop: poll the drawing API for waiting tasks and render them.

    For each task this function does four things:
    - (re)loads ``/data/{task['ckpt']}.yaml``/``.ckpt`` if it differs from
      the checkpoint in memory;
    - samples ``task['number']`` images for ``task['prompt']`` (or for the
      ``--from-file`` prompts);
    - saves them under ``<outdir>/samples``;
    - reports progress through ``update_task_status``.

    Loops forever; SIGINT/SIGTERM stop it through the ``quit`` handler.
    """
    import gc  # only needed when swapping checkpoints

    signal.signal(signal.SIGINT, quit)
    signal.signal(signal.SIGTERM, quit)
    model_name = None  # checkpoint currently loaded (None = nothing loaded yet)
    model = None
    sampler = None
    device = None
    # Loop-invariant setup: watermark encoder and precision context.
    wm = "SDV2"
    wm_encoder = WatermarkEncoder()
    wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
    precision_scope = autocast if opt.precision == "autocast" or opt.bf16 else nullcontext
    while True:
        for task in get_tasks():
            print('task:', task)
            # Load a different checkpoint only when the task asks for one.
            if task['ckpt'] != model_name:
                update_task_status(task, "init", 0)
                # Release the previous model (and the sampler referencing it)
                # first, so two models never occupy GPU memory at once.
                model = sampler = None
                gc.collect()
                torch.cuda.empty_cache()
                new_name = task['ckpt']
                # SECURITY NOTE(review): ``ckpt`` comes from the HTTP API and is
                # interpolated into a filesystem path that torch.load unpickles;
                # validate it (e.g. reject path separators / "..") if the API
                # is not fully trusted.
                opt.config = f'/data/{new_name}.yaml'
                opt.ckpt = f'/data/{new_name}.ckpt'
                # NOTE(review): the worker always runs on CUDA, regardless of --device.
                opt.device = 'cuda'
                print(f"config: {opt.config}", f"ckpt: {opt.ckpt}", f"device: {opt.device}")
                config = OmegaConf.load(f"{opt.config}")
                device = torch.device("cuda") if opt.device == "cuda" else torch.device("cpu")
                print(f"加载模型到显存: {new_name}..")
                model = load_model_from_config(config, f"{opt.ckpt}", device)
                sampler = _make_sampler(opt, model, device)
                # Record the name only once loading has succeeded.
                model_name = new_name
                print(f"加载到显存完成: {model_name}")
                update_task_status(task, "running", 0.5)

            batch_size = task['number']
            data = _prompt_batches(opt, task.get('prompt'), batch_size)
            print("data:", data)

            sample_path = os.path.join(opt.outdir, "samples")
            os.makedirs(sample_path, exist_ok=True)
            base_count = len(os.listdir(sample_path))

            shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
            start_code = None
            seed = task.get('seed')
            if seed:
                # The seed drives a dedicated generator, so the initial latent
                # noise is reproducible without reseeding the global RNG.
                # Dimension 0 is the batch, one noise map per image.
                generator = torch.Generator(device=device).manual_seed(int(seed))
                start_code = torch.randn([batch_size, *shape], generator=generator, device=device)
            update_task_status(task, "running", 0.8)

            with torch.no_grad(), precision_scope(opt.device), model.ema_scope():
                images = []
                for prompts in tqdm(data, desc="data"):
                    prompts = list(prompts)
                    # The last --from-file chunk may be shorter than batch_size.
                    n_prompts = len(prompts)
                    uc = None
                    if opt.scale != 1.0:
                        uc = model.get_learned_conditioning(n_prompts * [""])
                    c = model.get_learned_conditioning(prompts)
                    update_task_status(task=task, status='diffusing', progress=0.9)
                    samples, _ = sampler.sample(S=opt.steps,
                                                conditioning=c,
                                                batch_size=n_prompts,
                                                shape=shape,
                                                verbose=False,
                                                unconditional_guidance_scale=opt.scale,
                                                unconditional_conditioning=uc,
                                                eta=opt.ddim_eta,
                                                x_T=None if start_code is None else start_code[:n_prompts])
                    # 'build' follows 'diffusing' (0.9), so its progress must not go backwards.
                    update_task_status(task=task, status='build', progress=0.95)
                    x_samples = model.decode_first_stage(samples)
                    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
                    paths, base_count = _save_samples(x_samples, sample_path, base_count, wm_encoder)
                    images.extend(paths)
            update_task_status(task=task, status='done', progress=1, data=images)
            print("批次任务结束..")
if __name__ == "__main__":
    # Script entry point: read the command line, then hand control to the worker loop.
    opt = parse_args()
    main_dev(opt=opt)