DEBUG
This commit is contained in:
16
server.py
16
server.py
@@ -229,7 +229,8 @@ def get_tasks(tasks:list=[]):
|
||||
tasks = requests.get("http://localhost:3000/api/drawing?status=waiting").json()
|
||||
if len(tasks) == 0: time.sleep(2)
|
||||
except:
|
||||
print("get tasks error")
|
||||
# 打印当前时间
|
||||
print("get tasks error", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
|
||||
time.sleep(2)
|
||||
return tasks
|
||||
|
||||
@@ -241,10 +242,11 @@ def main_dev(opt):
|
||||
while True:
|
||||
for task in get_tasks(): # 遍历 tasks 返回 dict
|
||||
print('task:', task) # 打印任务
|
||||
update_task_status(task, "running", 0) # 更新任务状态为运行中
|
||||
|
||||
# 如果模型不同,重新加载模型(注意释放内存)
|
||||
if task['ckpt'] != model_name:
|
||||
# 修改状态为加载模型
|
||||
update_task_status(task, "loading", 0)
|
||||
# 获取环境配置
|
||||
model_name = task['ckpt']
|
||||
opt.config = f'/data/{model_name}.yaml'
|
||||
@@ -257,6 +259,9 @@ def main_dev(opt):
|
||||
print(f"加载模型到显存: {model_name}..")
|
||||
model = load_model_from_config(config, f"{opt.ckpt}", device)
|
||||
print(f"加载到显存完成: {model_name}")
|
||||
|
||||
# 更新任务状态为运行中
|
||||
update_task_status(task, "running", 0)
|
||||
|
||||
# 使用指定的模型和配置文件进行推理一组参数
|
||||
if opt.plms:
|
||||
@@ -300,13 +305,16 @@ def main_dev(opt):
|
||||
start_code = None
|
||||
if opt.fixed_code:
|
||||
start_code = torch.randn([task['number'], opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
|
||||
|
||||
# 更新进度
|
||||
update_task_status(task, "running", 0.1)
|
||||
|
||||
# 生成图片
|
||||
precision_scope = autocast if opt.precision == "autocast" or opt.bf16 else nullcontext
|
||||
with torch.no_grad(), precision_scope(opt.device), model.ema_scope():
|
||||
images = []
|
||||
# 执行指定的任务批次 (row)(task['number'])
|
||||
for n in trange(1, desc="Sampling"):
|
||||
print("Sampling:", data)
|
||||
for prompts in tqdm(data, desc="data"):
|
||||
uc = None
|
||||
if opt.scale != 1.0:
|
||||
@@ -327,7 +335,6 @@ def main_dev(opt):
|
||||
x_samples = model.decode_first_stage(samples)
|
||||
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
for x_sample in x_samples:
|
||||
print("Sample count:", sample_count)
|
||||
imge_path = os.path.join(sample_path, f"{base_count:05}.png")
|
||||
imge_path = os.path.abspath(imge_path) # 转换为绝对路径
|
||||
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
||||
@@ -337,7 +344,6 @@ def main_dev(opt):
|
||||
base_count += 1
|
||||
sample_count += 1
|
||||
images.append(imge_path)
|
||||
print("Sample count:", sample_count)
|
||||
update_task_status(task=task, status='done', progress=1, data=images) # 修改任务状态为完成
|
||||
print("批次任务结束..")
|
||||
#break
|
||||
|
@@ -8,7 +8,7 @@ export default defineEventHandler(async event => {
|
||||
if (!fs.existsSync('outputs')) {
|
||||
// 打印当前执行目录
|
||||
console.log('cwd:', process.cwd())
|
||||
path = `../../outputs/txt2img-samples/samples/${event.context.params.id}`
|
||||
path = `../outputs/txt2img-samples/samples/${event.context.params.id}`
|
||||
}
|
||||
console.log('path:', path)
|
||||
return fs.readFileSync(path)
|
||||
|
Reference in New Issue
Block a user