commit 4d6ca94eff (parent eb0c491577)
Date: 2023-02-19 05:06:23 +08:00
2 changed files with 12 additions and 6 deletions


@@ -229,7 +229,8 @@ def get_tasks(tasks:list=[]):
         tasks = requests.get("http://localhost:3000/api/drawing?status=waiting").json()
         if len(tasks) == 0: time.sleep(2)
     except:
-        print("get tasks error")
+        # print the current time
+        print("get tasks error", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
         time.sleep(2)
     return tasks
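
For reference, a minimal sketch of the polling helper as it reads after this change; the try/except framing and the imports are assumed from the surrounding script rather than shown in the hunk itself.

# Sketch of get_tasks after this commit; only the hunk above is taken verbatim,
# the rest is assumed from context.
import time
import requests

def get_tasks(tasks: list = []):
    try:
        # poll the local API for tasks still in the "waiting" state
        tasks = requests.get("http://localhost:3000/api/drawing?status=waiting").json()
        if len(tasks) == 0:
            time.sleep(2)
    except Exception:
        # log the failure together with the current timestamp
        print("get tasks error", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        time.sleep(2)
    return tasks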
@@ -241,10 +242,11 @@ def main_dev(opt):
     while True:
         for task in get_tasks():  # iterate over the returned task dicts
             print('task:', task)  # print the task
-            update_task_status(task, "running", 0)  # set the task status to running
             # if the model differs, reload it (take care to free the memory)
             if task['ckpt'] != model_name:
+                # set the status to loading the model
+                update_task_status(task, "loading", 0)
                 # get the environment config
                 model_name = task['ckpt']
                 opt.config = f'/data/{model_name}.yaml'
@@ -257,6 +259,9 @@ def main_dev(opt):
print(f"加载模型到显存: {model_name}..") print(f"加载模型到显存: {model_name}..")
model = load_model_from_config(config, f"{opt.ckpt}", device) model = load_model_from_config(config, f"{opt.ckpt}", device)
print(f"加载到显存完成: {model_name}") print(f"加载到显存完成: {model_name}")
# 更新任务状态为运行中
update_task_status(task, "running", 0)
# 使用指定的模型和配置文件进行推理一组参数 # 使用指定的模型和配置文件进行推理一组参数
if opt.plms: if opt.plms:
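
update_task_status is called throughout these hunks but its definition is not part of the diff. Purely as a hypothetical sketch, it could report status and progress back to the same local drawing API; the endpoint, HTTP method, payload keys and the use of task['id'] below are assumptions, not taken from the source.

import requests

# Hypothetical helper, not part of this commit: the endpoint, method,
# payload shape and task['id'] are all assumptions.
def update_task_status(task, status, progress, data=None):
    payload = {"status": status, "progress": progress, "data": data}
    requests.put(f"http://localhost:3000/api/drawing/{task['id']}", json=payload)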
@@ -300,13 +305,16 @@ def main_dev(opt):
             start_code = None
             if opt.fixed_code:
                 start_code = torch.randn([task['number'], opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
+            # update the progress
+            update_task_status(task, "running", 0.1)
             # generate the images
             precision_scope = autocast if opt.precision == "autocast" or opt.bf16 else nullcontext
             with torch.no_grad(), precision_scope(opt.device), model.ema_scope():
                 images = []
                 # run the specified task batches (row)(task['number'])
                 for n in trange(1, desc="Sampling"):
+                    print("Sampling:", data)
                     for prompts in tqdm(data, desc="data"):
                         uc = None
                         if opt.scale != 1.0:
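
The precision_scope line in this hunk picks torch.autocast when mixed precision is requested and a no-op context otherwise. A small standalone illustration of that pattern, with Opt standing in for the script's parsed CLI options (a hypothetical stub):

from contextlib import nullcontext

import torch
from torch import autocast

class Opt:
    # stand-in for the parsed CLI options used by the script (hypothetical)
    precision = "autocast"
    bf16 = False
    device = "cuda"

opt = Opt()
# torch.autocast when mixed precision is requested, otherwise a no-op context
precision_scope = autocast if opt.precision == "autocast" or opt.bf16 else nullcontext
with torch.no_grad(), precision_scope(opt.device):
    pass  # sampling would run here, under autocast or at full precision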
@@ -327,7 +335,6 @@ def main_dev(opt):
                         x_samples = model.decode_first_stage(samples)
                         x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
                         for x_sample in x_samples:
-                            print("Sample count:", sample_count)
                             imge_path = os.path.join(sample_path, f"{base_count:05}.png")
                             imge_path = os.path.abspath(imge_path)  # convert to an absolute path
                             x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
@@ -337,7 +344,6 @@ def main_dev(opt):
                             base_count += 1
                             sample_count += 1
                             images.append(imge_path)
-            print("Sample count:", sample_count)
             update_task_status(task=task, status='done', progress=1, data=images)  # set the task status to done
             print("Batch task finished..")
             #break
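
Taken together, the status updates in this commit give the task loop the following shape (a condensed, non-runnable sketch; model loading, prompt handling and sampling are elided with ...):

# Condensed view of main_dev's task loop after this commit (sketch only).
while True:
    for task in get_tasks():
        if task['ckpt'] != model_name:
            update_task_status(task, "loading", 0)   # while reloading the checkpoint
            ...                                      # load the model into VRAM
        update_task_status(task, "running", 0)       # model ready, work starts
        ...                                          # build prompts and start_code
        update_task_status(task, "running", 0.1)     # just before sampling
        ...                                          # sampling and saving images
        update_task_status(task=task, status='done', progress=1, data=images)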


@@ -8,7 +8,7 @@ export default defineEventHandler(async event => {
   if (!fs.existsSync('outputs')) {
     // print the current working directory
     console.log('cwd:', process.cwd())
-    path = `../../outputs/txt2img-samples/samples/${event.context.params.id}`
+    path = `../outputs/txt2img-samples/samples/${event.context.params.id}`
   }
   console.log('path:', path)
   return fs.readFileSync(path)