ckpt load

This commit is contained in:
2023-03-03 04:54:52 +08:00
parent 75274a9015
commit bc6c50964d
2 changed files with 12 additions and 9 deletions

View File

@@ -34,6 +34,7 @@ def load_model_from_config(config, ckpt, device=torch.device("cuda"), verbose=Fa
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
print('torch: load ckpt done')
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)