diff --git a/train.py b/train.py index f8e9839b..e6fc3eb8 100644 --- a/train.py +++ b/train.py @@ -215,8 +215,7 @@ def train(args, model, opt, train_iters, train_iterations, field, rank=0, world_ torch.save({'model_state_dict': {k: v.cpu() for k, v in model.state_dict().items()}, 'field': field}, os.path.join(args.log_dir, f'iteration_{iteration}.pth')) if world_size > 1: torch.distributed.barrier() - torch.save(opt.state_dict(), os.path.join(args.log_dir, f'iteration_{iteration}_rank_{rank}_optim.pth')) - if world_size > 1: + torch.save(opt.state_dict(), os.path.join(args.log_dir, f'iteration_{iteration}_rank_{rank}_optim.pth')) torch.distributed.barrier() # lr update