Model name

2023-02-18 03:19:56 +08:00
parent af09452013
commit 4cdd13c688


@@ -225,6 +225,27 @@ def main_dev(opt):
     # Iterate over `data`, which yields dicts
     for item in data:
         print(item)  # {'model': '768-v-ema', 'prompt': '一只猫', 'watermark': '0'}
+        '''
+        task: {
+            model: 'SD2',
+            ckpt: 'latest',
+            prompt: '猫猫',
+            number: 1,
+            tid: '06ruxroiuo3u',
+            uid: 1234567890,
+            status: 'waiting',
+            createdAt: 1676660844133,
+            remove: '狗狗',
+            w: 512,
+            h: 512,
+            seed: 0,
+            sampler: 'pndm',
+            prompt_guidance: 0.5,
+            quality_details: 25,
+            image: '',
+            data: null
+        }
+        '''
         # Set the parameters
         if 'prompt' in item: opt.prompt = item['prompt']  # description
         if 'n_samples' in item: opt.n_samples = item['n_samples']  # number of columns
@@ -232,9 +253,9 @@ def main_dev(opt):
         if 'scale' in item: opt.scale = item['scale']  # scale
         # If the model is different, reload it (take care to free memory)
-        if item['model'] != model_name:
+        if item['ckpt'] != model_name:
             # Get the environment configuration
-            model_name = item['model']
+            model_name = item['ckpt']
             opt.config = f'/data/{model_name}.yaml'
             opt.ckpt = f'/data/{model_name}.ckpt'
             opt.device = 'cuda'
@@ -242,10 +263,9 @@ def main_dev(opt):
             config = OmegaConf.load(f"{opt.config}")
             device = torch.device("cuda") if opt.device == "cuda" else torch.device("cpu")
             # Load the model (into GPU memory)
-            print(f"load model: {item['model']}..")
-            model_name = item['model']
+            print(f"Loading model into GPU memory: {model_name}..")
             model = load_model_from_config(config, f"{opt.ckpt}", device)
-            print(f"model_name: {model_name}")
+            print(f"Finished loading into GPU memory: {model_name}")
         # Run inference for one set of parameters with the specified model and config file
         if opt.plms:
             sampler = PLMSSampler(model, device=device)
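
The diff above keys the reload check on the task's 'ckpt' field instead of 'model', so a checkpoint is only loaded when it differs from the one already resident in GPU memory. Below is a minimal, standalone sketch of that reload-on-change pattern. The helper names `get_model` and `load_model` are hypothetical stand-ins (the repo uses OmegaConf plus load_model_from_config); only the caching and memory-freeing logic is illustrated.

import gc
import torch

_loaded_name = None   # name of the checkpoint currently in GPU memory
_loaded_model = None  # the model object itself

def get_model(ckpt_name, load_model):
    """Return a model for ckpt_name, reloading only when the name changes."""
    global _loaded_name, _loaded_model
    if ckpt_name != _loaded_name:
        # Drop the previous model and free GPU memory before loading the new one.
        _loaded_model = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        _loaded_model = load_model(ckpt_name)  # e.g. OmegaConf config + load_model_from_config
        _loaded_name = ckpt_name
    return _loaded_model

In the worker loop this would be called once per task, e.g. `model = get_model(item['ckpt'], load_model)`, so consecutive tasks that request the same checkpoint skip the expensive load entirely.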