合并
This commit is contained in:
		
							
								
								
									
										5
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										5
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@@ -7,6 +7,11 @@ node_modules
 | 
				
			|||||||
.output
 | 
					.output
 | 
				
			||||||
.env
 | 
					.env
 | 
				
			||||||
dist
 | 
					dist
 | 
				
			||||||
 | 
					data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# python
 | 
					# python
 | 
				
			||||||
venv
 | 
					venv
 | 
				
			||||||
 | 
					outputs
 | 
				
			||||||
 | 
					*.egg-info
 | 
				
			||||||
 | 
					*.egg
 | 
				
			||||||
 | 
					*.pyc
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										363
									
								
								server.py
									
									
									
									
									
								
							
							
						
						
									
										363
									
								
								server.py
									
									
									
									
									
								
							@@ -213,209 +213,186 @@ def put_watermark(img, wm_encoder=None):
 | 
				
			|||||||
import time
 | 
					import time
 | 
				
			||||||
import requests
 | 
					import requests
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# 获取model, 如果和之前的model不一样,重新加载
 | 
					 | 
				
			||||||
def get_model(model_name):
 | 
					 | 
				
			||||||
    global model
 | 
					 | 
				
			||||||
    global config
 | 
					 | 
				
			||||||
    global device
 | 
					 | 
				
			||||||
    if model_name != model_name:
 | 
					 | 
				
			||||||
        config = OmegaConf.load(f"{opt.config}")
 | 
					 | 
				
			||||||
        device = torch.device("cuda") if opt.device == "cuda" else torch.device("cpu")
 | 
					 | 
				
			||||||
        model = load_model_from_config(config, f"{opt.ckpt}", device)
 | 
					 | 
				
			||||||
    return model
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# 使用指定的模型和配置文件进行推理一组参数
 | 
					 | 
				
			||||||
def drawing(model_name):
 | 
					 | 
				
			||||||
    model = get_model(model_name)
 | 
					 | 
				
			||||||
    if opt.plms:
 | 
					 | 
				
			||||||
        sampler = PLMSSampler(model, device=device)
 | 
					 | 
				
			||||||
    elif opt.dpm:
 | 
					 | 
				
			||||||
        sampler = DPMSolverSampler(model, device=device)
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        sampler = DDIMSampler(model, device=device)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def main_dev(opt):
 | 
					def main_dev(opt):
 | 
				
			||||||
 | 
					    model_name = ''   # 默认模型
 | 
				
			||||||
 | 
					    model      = None # 默认模型
 | 
				
			||||||
 | 
					    config     = None # 默认配置
 | 
				
			||||||
 | 
					    device     = None # 默认设备
 | 
				
			||||||
    while True:
 | 
					    while True:
 | 
				
			||||||
        time.sleep(1) # 延时1s执行, 避免cpu占用过高
 | 
					        time.sleep(2)                                                   # 延时1s执行, 避免cpu占用过高
 | 
				
			||||||
        # 从局域网中获取一组参数
 | 
					        data = requests.get("http://localhost:3000/api/drawing").json() # 从局域网中获取一组参数
 | 
				
			||||||
        request = requests.get("http://localhost:3000/api/drawing")
 | 
					        print(data) # [{'model': '768-v-ema', 'prompt': '一只猫', 'watermark': '0'}, {'model': '768-v-ema', 'prompt': '一只狗', 'watermark': '0'}]
 | 
				
			||||||
        if request.status_code == 200:
 | 
					        # 遍历 data 返回dict
 | 
				
			||||||
            data = request.json()
 | 
					        for item in data:
 | 
				
			||||||
            print("data: ", data)
 | 
					            print(item) # {'model': '768-v-ema', 'prompt': '一只猫', 'watermark': '0'}
 | 
				
			||||||
        #drawing("model_name")
 | 
					            # 设置参数
 | 
				
			||||||
 | 
					            if 'prompt'    in item: opt.prompt    = item['prompt']    # 描述
 | 
				
			||||||
 | 
					            if 'n_samples' in item: opt.n_samples = item['n_samples'] # 列数
 | 
				
			||||||
 | 
					            if 'n_rows'    in item: opt.n_rows    = item['n_rows']    # 行数
 | 
				
			||||||
 | 
					            if 'scale'     in item: opt.scale     = item['scale']     # 比例
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def main(opt):
 | 
					            # 如果模型不同,重新加载模型(注意释放内存)
 | 
				
			||||||
    seed_everything(opt.seed)
 | 
					            if item['model'] != model_name:
 | 
				
			||||||
 | 
					                # 获取环境配置
 | 
				
			||||||
    config = OmegaConf.load(f"{opt.config}")
 | 
					                model_name = item['model']
 | 
				
			||||||
    device = torch.device("cuda") if opt.device == "cuda" else torch.device("cpu")
 | 
					                opt.config = f'/data/{model_name}.yaml'
 | 
				
			||||||
    model = load_model_from_config(config, f"{opt.ckpt}", device)
 | 
					                opt.ckpt   = f'/data/{model_name}.ckpt'
 | 
				
			||||||
 | 
					                opt.device = 'cuda'
 | 
				
			||||||
    if opt.plms:
 | 
					                print(f"config: {opt.config}", f"ckpt: {opt.ckpt}", f"device: {opt.device}")
 | 
				
			||||||
        sampler = PLMSSampler(model, device=device)
 | 
					                config = OmegaConf.load(f"{opt.config}")
 | 
				
			||||||
    elif opt.dpm:
 | 
					                device = torch.device("cuda") if opt.device == "cuda" else torch.device("cpu")
 | 
				
			||||||
        sampler = DPMSolverSampler(model, device=device)
 | 
					                # 加载模型(到显存)
 | 
				
			||||||
    else:
 | 
					                print(f"load model: {item['model']}..")
 | 
				
			||||||
        sampler = DDIMSampler(model, device=device)
 | 
					                model_name = item['model']
 | 
				
			||||||
 | 
					                model = load_model_from_config(config, f"{opt.ckpt}", device)
 | 
				
			||||||
    os.makedirs(opt.outdir, exist_ok=True)
 | 
					                print(f"model_name: {model_name}")
 | 
				
			||||||
    outpath = opt.outdir
 | 
					            # 使用指定的模型和配置文件进行推理一组参数
 | 
				
			||||||
 | 
					            if opt.plms:
 | 
				
			||||||
    print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
 | 
					                sampler = PLMSSampler(model, device=device)
 | 
				
			||||||
    wm = "SDV2"
 | 
					            elif opt.dpm:
 | 
				
			||||||
    wm_encoder = WatermarkEncoder()
 | 
					                sampler = DPMSolverSampler(model, device=device)
 | 
				
			||||||
    wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
 | 
					            else:
 | 
				
			||||||
 | 
					                sampler = DDIMSampler(model, device=device)
 | 
				
			||||||
    batch_size = opt.n_samples
 | 
					            # 检查输出目录是否存在
 | 
				
			||||||
    n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
 | 
					            os.makedirs(opt.outdir, exist_ok=True)
 | 
				
			||||||
    if not opt.from_file:
 | 
					            outpath = opt.outdir
 | 
				
			||||||
        prompt = opt.prompt
 | 
					            # 创建水印编码器
 | 
				
			||||||
        assert prompt is not None
 | 
					            wm = "SDV2"
 | 
				
			||||||
        data = [batch_size * [prompt]]
 | 
					            wm_encoder = WatermarkEncoder()
 | 
				
			||||||
 | 
					            wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
 | 
				
			||||||
    else:
 | 
					            # x
 | 
				
			||||||
        print(f"reading prompts from {opt.from_file}")
 | 
					            batch_size = opt.n_samples
 | 
				
			||||||
        with open(opt.from_file, "r") as f:
 | 
					            n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
 | 
				
			||||||
            data = f.read().splitlines()
 | 
					            if not opt.from_file:
 | 
				
			||||||
            data = [p for p in data for i in range(opt.repeat)]
 | 
					                prompt = opt.prompt
 | 
				
			||||||
            data = list(chunk(data, batch_size))
 | 
					                assert prompt is not None
 | 
				
			||||||
 | 
					                data = [batch_size * [prompt]]
 | 
				
			||||||
    sample_path = os.path.join(outpath, "samples")
 | 
					            else:
 | 
				
			||||||
    os.makedirs(sample_path, exist_ok=True)
 | 
					                print(f"reading prompts from {opt.from_file}")
 | 
				
			||||||
    sample_count = 0
 | 
					                with open(opt.from_file, "r") as f:
 | 
				
			||||||
    base_count = len(os.listdir(sample_path))
 | 
					                    data = f.read().splitlines()
 | 
				
			||||||
    grid_count = len(os.listdir(outpath)) - 1
 | 
					                    data = [p for p in data for i in range(opt.repeat)]
 | 
				
			||||||
 | 
					                    data = list(chunk(data, batch_size))
 | 
				
			||||||
    start_code = None
 | 
					            # x
 | 
				
			||||||
    if opt.fixed_code:
 | 
					            sample_path = os.path.join(outpath, "samples")
 | 
				
			||||||
        start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
 | 
					            os.makedirs(sample_path, exist_ok=True)
 | 
				
			||||||
 | 
					            sample_count = 0
 | 
				
			||||||
    if opt.torchscript or opt.ipex:
 | 
					            base_count = len(os.listdir(sample_path))
 | 
				
			||||||
        transformer = model.cond_stage_model.model
 | 
					            grid_count = len(os.listdir(outpath)) - 1
 | 
				
			||||||
        unet = model.model.diffusion_model
 | 
					            # x
 | 
				
			||||||
        decoder = model.first_stage_model.decoder
 | 
					            start_code = None
 | 
				
			||||||
        additional_context = torch.cpu.amp.autocast() if opt.bf16 else nullcontext()
 | 
					            if opt.fixed_code:
 | 
				
			||||||
        shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
 | 
					                start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
 | 
				
			||||||
 | 
					            # x
 | 
				
			||||||
        if opt.bf16 and not opt.torchscript and not opt.ipex:
 | 
					            if opt.torchscript or opt.ipex:
 | 
				
			||||||
            raise ValueError('Bfloat16 is supported only for torchscript+ipex')
 | 
					                transformer = model.cond_stage_model.model
 | 
				
			||||||
        if opt.bf16 and unet.dtype != torch.bfloat16:
 | 
					                unet = model.model.diffusion_model
 | 
				
			||||||
            raise ValueError("Use configs/stable-diffusion/intel/ configs with bf16 enabled if " +
 | 
					                decoder = model.first_stage_model.decoder
 | 
				
			||||||
                             "you'd like to use bfloat16 with CPU.")
 | 
					                additional_context = torch.cpu.amp.autocast() if opt.bf16 else nullcontext()
 | 
				
			||||||
        if unet.dtype == torch.float16 and device == torch.device("cpu"):
 | 
					                shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
 | 
				
			||||||
            raise ValueError("Use configs/stable-diffusion/intel/ configs for your model if you'd like to run it on CPU.")
 | 
					                if opt.bf16 and not opt.torchscript and not opt.ipex:
 | 
				
			||||||
 | 
					                    raise ValueError('Bfloat16 is supported only for torchscript+ipex')
 | 
				
			||||||
        if opt.ipex:
 | 
					                if opt.bf16 and unet.dtype != torch.bfloat16:
 | 
				
			||||||
            import intel_extension_for_pytorch as ipex
 | 
					                    raise ValueError("Use configs/stable-diffusion/intel/ configs with bf16 enabled if " +
 | 
				
			||||||
            bf16_dtype = torch.bfloat16 if opt.bf16 else None
 | 
					                                     "you'd like to use bfloat16 with CPU.")
 | 
				
			||||||
            transformer = transformer.to(memory_format=torch.channels_last)
 | 
					                if unet.dtype == torch.float16 and device == torch.device("cpu"):
 | 
				
			||||||
            transformer = ipex.optimize(transformer, level="O1", inplace=True)
 | 
					                    raise ValueError("Use configs/stable-diffusion/intel/ configs for your model if you'd like to run it on CPU.")
 | 
				
			||||||
 | 
					                if opt.ipex:
 | 
				
			||||||
            unet = unet.to(memory_format=torch.channels_last)
 | 
					                    import intel_extension_for_pytorch as ipex
 | 
				
			||||||
            unet = ipex.optimize(unet, level="O1", auto_kernel_selection=True, inplace=True, dtype=bf16_dtype)
 | 
					                    bf16_dtype = torch.bfloat16 if opt.bf16 else None
 | 
				
			||||||
 | 
					                    transformer = transformer.to(memory_format=torch.channels_last)
 | 
				
			||||||
            decoder = decoder.to(memory_format=torch.channels_last)
 | 
					                    transformer = ipex.optimize(transformer, level="O1", inplace=True)
 | 
				
			||||||
            decoder = ipex.optimize(decoder, level="O1", auto_kernel_selection=True, inplace=True, dtype=bf16_dtype)
 | 
					                    unet = unet.to(memory_format=torch.channels_last)
 | 
				
			||||||
 | 
					                    unet = ipex.optimize(unet, level="O1", auto_kernel_selection=True, inplace=True, dtype=bf16_dtype)
 | 
				
			||||||
        if opt.torchscript:
 | 
					                    decoder = decoder.to(memory_format=torch.channels_last)
 | 
				
			||||||
            with torch.no_grad(), additional_context:
 | 
					                    decoder = ipex.optimize(decoder, level="O1", auto_kernel_selection=True, inplace=True, dtype=bf16_dtype)
 | 
				
			||||||
                # get UNET scripted
 | 
					                if opt.torchscript:
 | 
				
			||||||
                if unet.use_checkpoint:
 | 
					                    with torch.no_grad(), additional_context:
 | 
				
			||||||
                    raise ValueError("Gradient checkpoint won't work with tracing. " +
 | 
					                        # get UNET scripted
 | 
				
			||||||
                    "Use configs/stable-diffusion/intel/ configs for your model or disable checkpoint in your config.")
 | 
					                        if unet.use_checkpoint:
 | 
				
			||||||
 | 
					                            raise ValueError("Gradient checkpoint won't work with tracing. " +
 | 
				
			||||||
                img_in = torch.ones(2, 4, 96, 96, dtype=torch.float32)
 | 
					                            "Use configs/stable-diffusion/intel/ configs for your model or disable checkpoint in your config.")
 | 
				
			||||||
                t_in = torch.ones(2, dtype=torch.int64)
 | 
					                        img_in = torch.ones(2, 4, 96, 96, dtype=torch.float32)
 | 
				
			||||||
                context = torch.ones(2, 77, 1024, dtype=torch.float32)
 | 
					                        t_in = torch.ones(2, dtype=torch.int64)
 | 
				
			||||||
                scripted_unet = torch.jit.trace(unet, (img_in, t_in, context))
 | 
					                        context = torch.ones(2, 77, 1024, dtype=torch.float32)
 | 
				
			||||||
                scripted_unet = torch.jit.optimize_for_inference(scripted_unet)
 | 
					                        scripted_unet = torch.jit.trace(unet, (img_in, t_in, context))
 | 
				
			||||||
                print(type(scripted_unet))
 | 
					                        scripted_unet = torch.jit.optimize_for_inference(scripted_unet)
 | 
				
			||||||
                model.model.scripted_diffusion_model = scripted_unet
 | 
					                        print(type(scripted_unet))
 | 
				
			||||||
 | 
					                        model.model.scripted_diffusion_model = scripted_unet
 | 
				
			||||||
                # get Decoder for first stage model scripted
 | 
					                        # get Decoder for first stage model scripted
 | 
				
			||||||
                samples_ddim = torch.ones(1, 4, 96, 96, dtype=torch.float32)
 | 
					                        samples_ddim = torch.ones(1, 4, 96, 96, dtype=torch.float32)
 | 
				
			||||||
                scripted_decoder = torch.jit.trace(decoder, (samples_ddim))
 | 
					                        scripted_decoder = torch.jit.trace(decoder, (samples_ddim))
 | 
				
			||||||
                scripted_decoder = torch.jit.optimize_for_inference(scripted_decoder)
 | 
					                        scripted_decoder = torch.jit.optimize_for_inference(scripted_decoder)
 | 
				
			||||||
                print(type(scripted_decoder))
 | 
					                        print(type(scripted_decoder))
 | 
				
			||||||
                model.first_stage_model.decoder = scripted_decoder
 | 
					                        model.first_stage_model.decoder = scripted_decoder
 | 
				
			||||||
 | 
					                prompts = data[0]
 | 
				
			||||||
        prompts = data[0]
 | 
					                print("Running a forward pass to initialize optimizations")
 | 
				
			||||||
        print("Running a forward pass to initialize optimizations")
 | 
					                uc = None
 | 
				
			||||||
        uc = None
 | 
					                if opt.scale != 1.0:
 | 
				
			||||||
        if opt.scale != 1.0:
 | 
					                    uc = model.get_learned_conditioning(batch_size * [""])
 | 
				
			||||||
            uc = model.get_learned_conditioning(batch_size * [""])
 | 
					                if isinstance(prompts, tuple):
 | 
				
			||||||
        if isinstance(prompts, tuple):
 | 
					                    prompts = list(prompts)
 | 
				
			||||||
            prompts = list(prompts)
 | 
					                with torch.no_grad(), additional_context:
 | 
				
			||||||
 | 
					                    for _ in range(3):
 | 
				
			||||||
        with torch.no_grad(), additional_context:
 | 
					                        c = model.get_learned_conditioning(prompts)
 | 
				
			||||||
            for _ in range(3):
 | 
					                    samples_ddim, _ = sampler.sample(S=5,
 | 
				
			||||||
                c = model.get_learned_conditioning(prompts)
 | 
					 | 
				
			||||||
            samples_ddim, _ = sampler.sample(S=5,
 | 
					 | 
				
			||||||
                                             conditioning=c,
 | 
					 | 
				
			||||||
                                             batch_size=batch_size,
 | 
					 | 
				
			||||||
                                             shape=shape,
 | 
					 | 
				
			||||||
                                             verbose=False,
 | 
					 | 
				
			||||||
                                             unconditional_guidance_scale=opt.scale,
 | 
					 | 
				
			||||||
                                             unconditional_conditioning=uc,
 | 
					 | 
				
			||||||
                                             eta=opt.ddim_eta,
 | 
					 | 
				
			||||||
                                             x_T=start_code)
 | 
					 | 
				
			||||||
            print("Running a forward pass for decoder")
 | 
					 | 
				
			||||||
            for _ in range(3):
 | 
					 | 
				
			||||||
                x_samples_ddim = model.decode_first_stage(samples_ddim)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    precision_scope = autocast if opt.precision=="autocast" or opt.bf16 else nullcontext
 | 
					 | 
				
			||||||
    with torch.no_grad(), \
 | 
					 | 
				
			||||||
        precision_scope(opt.device), \
 | 
					 | 
				
			||||||
        model.ema_scope():
 | 
					 | 
				
			||||||
            all_samples = list()
 | 
					 | 
				
			||||||
            for n in trange(opt.n_iter, desc="Sampling"):
 | 
					 | 
				
			||||||
                for prompts in tqdm(data, desc="data"):
 | 
					 | 
				
			||||||
                    uc = None
 | 
					 | 
				
			||||||
                    if opt.scale != 1.0:
 | 
					 | 
				
			||||||
                        uc = model.get_learned_conditioning(batch_size * [""])
 | 
					 | 
				
			||||||
                    if isinstance(prompts, tuple):
 | 
					 | 
				
			||||||
                        prompts = list(prompts)
 | 
					 | 
				
			||||||
                    c = model.get_learned_conditioning(prompts)
 | 
					 | 
				
			||||||
                    shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
 | 
					 | 
				
			||||||
                    samples, _ = sampler.sample(S=opt.steps,
 | 
					 | 
				
			||||||
                                                     conditioning=c,
 | 
					                                                     conditioning=c,
 | 
				
			||||||
                                                     batch_size=opt.n_samples,
 | 
					                                                     batch_size=batch_size,
 | 
				
			||||||
                                                     shape=shape,
 | 
					                                                     shape=shape,
 | 
				
			||||||
                                                     verbose=False,
 | 
					                                                     verbose=False,
 | 
				
			||||||
                                                     unconditional_guidance_scale=opt.scale,
 | 
					                                                     unconditional_guidance_scale=opt.scale,
 | 
				
			||||||
                                                     unconditional_conditioning=uc,
 | 
					                                                     unconditional_conditioning=uc,
 | 
				
			||||||
                                                     eta=opt.ddim_eta,
 | 
					                                                     eta=opt.ddim_eta,
 | 
				
			||||||
                                                     x_T=start_code)
 | 
					                                                     x_T=start_code)
 | 
				
			||||||
 | 
					                    print("Running a forward pass for decoder")
 | 
				
			||||||
                    x_samples = model.decode_first_stage(samples)
 | 
					                    for _ in range(3):
 | 
				
			||||||
                    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
 | 
					                        x_samples_ddim = model.decode_first_stage(samples_ddim)
 | 
				
			||||||
 | 
					            precision_scope = autocast if opt.precision=="autocast" or opt.bf16 else nullcontext
 | 
				
			||||||
                    for x_sample in x_samples:
 | 
					            with torch.no_grad(), \
 | 
				
			||||||
                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
 | 
					                precision_scope(opt.device), \
 | 
				
			||||||
                        img = Image.fromarray(x_sample.astype(np.uint8))
 | 
					                model.ema_scope():
 | 
				
			||||||
                        img = put_watermark(img, wm_encoder)
 | 
					                    all_samples = list()
 | 
				
			||||||
                        img.save(os.path.join(sample_path, f"{base_count:05}.png"))
 | 
					                    for n in trange(opt.n_iter, desc="Sampling"):
 | 
				
			||||||
                        base_count += 1
 | 
					                        for prompts in tqdm(data, desc="data"):
 | 
				
			||||||
                        sample_count += 1
 | 
					                            uc = None
 | 
				
			||||||
 | 
					                            if opt.scale != 1.0:
 | 
				
			||||||
                    all_samples.append(x_samples)
 | 
					                                uc = model.get_learned_conditioning(batch_size * [""])
 | 
				
			||||||
 | 
					                            if isinstance(prompts, tuple):
 | 
				
			||||||
            # additionally, save as grid
 | 
					                                prompts = list(prompts)
 | 
				
			||||||
            grid = torch.stack(all_samples, 0)
 | 
					                            c = model.get_learned_conditioning(prompts)
 | 
				
			||||||
            grid = rearrange(grid, 'n b c h w -> (n b) c h w')
 | 
					                            shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
 | 
				
			||||||
            grid = make_grid(grid, nrow=n_rows)
 | 
					                            samples, _ = sampler.sample(S=opt.steps,
 | 
				
			||||||
 | 
					                                                             conditioning=c,
 | 
				
			||||||
            # to image
 | 
					                                                             batch_size=opt.n_samples,
 | 
				
			||||||
            grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
 | 
					                                                             shape=shape,
 | 
				
			||||||
            grid = Image.fromarray(grid.astype(np.uint8))
 | 
					                                                             verbose=False,
 | 
				
			||||||
            grid = put_watermark(grid, wm_encoder)
 | 
					                                                             unconditional_guidance_scale=opt.scale,
 | 
				
			||||||
            grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
 | 
					                                                             unconditional_conditioning=uc,
 | 
				
			||||||
            grid_count += 1
 | 
					                                                             eta=opt.ddim_eta,
 | 
				
			||||||
 | 
					                                                             x_T=start_code)
 | 
				
			||||||
    print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
 | 
					                            x_samples = model.decode_first_stage(samples)
 | 
				
			||||||
          f" \nEnjoy.")
 | 
					                            x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
 | 
				
			||||||
 | 
					                            for x_sample in x_samples:
 | 
				
			||||||
 | 
					                                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
 | 
				
			||||||
 | 
					                                img = Image.fromarray(x_sample.astype(np.uint8))
 | 
				
			||||||
 | 
					                                img = put_watermark(img, wm_encoder)
 | 
				
			||||||
 | 
					                                img.save(os.path.join(sample_path, f"{base_count:05}.png"))
 | 
				
			||||||
 | 
					                                base_count += 1
 | 
				
			||||||
 | 
					                                sample_count += 1
 | 
				
			||||||
 | 
					                            all_samples.append(x_samples)
 | 
				
			||||||
 | 
					                    # additionally, save as grid
 | 
				
			||||||
 | 
					                    grid = torch.stack(all_samples, 0)
 | 
				
			||||||
 | 
					                    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
 | 
				
			||||||
 | 
					                    grid = make_grid(grid, nrow=n_rows)
 | 
				
			||||||
 | 
					                    # to image
 | 
				
			||||||
 | 
					                    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
 | 
				
			||||||
 | 
					                    grid = Image.fromarray(grid.astype(np.uint8))
 | 
				
			||||||
 | 
					                    grid = put_watermark(grid, wm_encoder)
 | 
				
			||||||
 | 
					                    grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
 | 
				
			||||||
 | 
					                    grid_count += 1
 | 
				
			||||||
 | 
					            print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
 | 
				
			||||||
 | 
					                  f" \nEnjoy.")
 | 
				
			||||||
 | 
					        break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    opt = parse_args()
 | 
					    opt = parse_args()
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user