模型名
This commit is contained in:
30
server.py
30
server.py
@@ -225,6 +225,27 @@ def main_dev(opt):
|
|||||||
# 遍历 data 返回dict
|
# 遍历 data 返回dict
|
||||||
for item in data:
|
for item in data:
|
||||||
print(item) # {'model': '768-v-ema', 'prompt': '一只猫', 'watermark': '0'}
|
print(item) # {'model': '768-v-ema', 'prompt': '一只猫', 'watermark': '0'}
|
||||||
|
'''
|
||||||
|
task: {
|
||||||
|
model: 'SD2',
|
||||||
|
ckpt: 'latest',
|
||||||
|
prompt: '猫猫',
|
||||||
|
number: 1,
|
||||||
|
tid: '06ruxroiuo3u',
|
||||||
|
uid: 1234567890,
|
||||||
|
status: 'waiting',
|
||||||
|
createdAt: 1676660844133,
|
||||||
|
remove: '狗狗',
|
||||||
|
w: 512,
|
||||||
|
h: 512,
|
||||||
|
seed: 0,
|
||||||
|
sampler: 'pndm',
|
||||||
|
prompt_guidance: 0.5,
|
||||||
|
quality_details: 25,
|
||||||
|
image: '',
|
||||||
|
data: null
|
||||||
|
}
|
||||||
|
'''
|
||||||
# 设置参数
|
# 设置参数
|
||||||
if 'prompt' in item: opt.prompt = item['prompt'] # 描述
|
if 'prompt' in item: opt.prompt = item['prompt'] # 描述
|
||||||
if 'n_samples' in item: opt.n_samples = item['n_samples'] # 列数
|
if 'n_samples' in item: opt.n_samples = item['n_samples'] # 列数
|
||||||
@@ -232,9 +253,9 @@ def main_dev(opt):
|
|||||||
if 'scale' in item: opt.scale = item['scale'] # 比例
|
if 'scale' in item: opt.scale = item['scale'] # 比例
|
||||||
|
|
||||||
# 如果模型不同,重新加载模型(注意释放内存)
|
# 如果模型不同,重新加载模型(注意释放内存)
|
||||||
if item['model'] != model_name:
|
if item['ckpt'] != model_name:
|
||||||
# 获取环境配置
|
# 获取环境配置
|
||||||
model_name = item['model']
|
model_name = item['ckpt']
|
||||||
opt.config = f'/data/{model_name}.yaml'
|
opt.config = f'/data/{model_name}.yaml'
|
||||||
opt.ckpt = f'/data/{model_name}.ckpt'
|
opt.ckpt = f'/data/{model_name}.ckpt'
|
||||||
opt.device = 'cuda'
|
opt.device = 'cuda'
|
||||||
@@ -242,10 +263,9 @@ def main_dev(opt):
|
|||||||
config = OmegaConf.load(f"{opt.config}")
|
config = OmegaConf.load(f"{opt.config}")
|
||||||
device = torch.device("cuda") if opt.device == "cuda" else torch.device("cpu")
|
device = torch.device("cuda") if opt.device == "cuda" else torch.device("cpu")
|
||||||
# 加载模型(到显存)
|
# 加载模型(到显存)
|
||||||
print(f"load model: {item['model']}..")
|
print(f"加载模型到显存: {model_name}..")
|
||||||
model_name = item['model']
|
|
||||||
model = load_model_from_config(config, f"{opt.ckpt}", device)
|
model = load_model_from_config(config, f"{opt.ckpt}", device)
|
||||||
print(f"model_name: {model_name}")
|
print(f"加载到显存完成: {model_name}")
|
||||||
# 使用指定的模型和配置文件进行推理一组参数
|
# 使用指定的模型和配置文件进行推理一组参数
|
||||||
if opt.plms:
|
if opt.plms:
|
||||||
sampler = PLMSSampler(model, device=device)
|
sampler = PLMSSampler(model, device=device)
|
||||||
|
Reference in New Issue
Block a user