From d4b35d7ae68aa14f5a89845944c40549b35d7a87 Mon Sep 17 00:00:00 2001
From: Giovanni Campagna
Date: Fri, 1 Mar 2019 10:51:12 -0800
Subject: [PATCH] train: don't save per-worker checkpoints if we're not doing
 distributed training

Saves disk space
---
 train.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/train.py b/train.py
index f8e9839b..e6fc3eb8 100644
--- a/train.py
+++ b/train.py
@@ -215,8 +215,7 @@ def train(args, model, opt, train_iters, train_iterations, field, rank=0, world_
         torch.save({'model_state_dict': {k: v.cpu() for k, v in model.state_dict().items()}, 'field': field}, os.path.join(args.log_dir, f'iteration_{iteration}.pth'))
         if world_size > 1:
             torch.distributed.barrier()
-        torch.save(opt.state_dict(), os.path.join(args.log_dir, f'iteration_{iteration}_rank_{rank}_optim.pth'))
-        if world_size > 1:
+            torch.save(opt.state_dict(), os.path.join(args.log_dir, f'iteration_{iteration}_rank_{rank}_optim.pth'))
             torch.distributed.barrier()
 
         # lr update
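
For reference, a minimal sketch of how the checkpointing block reads once this patch is applied: every rank still writes the model weights, but the per-rank optimizer checkpoint (and the surrounding barriers) only happens when world_size > 1. The save_checkpoint wrapper name is hypothetical; args, opt, field, iteration, rank and world_size mirror the names in the hunk above.

    import os
    import torch

    def save_checkpoint(model, opt, field, args, iteration, rank=0, world_size=1):
        # Hypothetical helper mirroring the patched block in train.py.
        # Every rank saves the CPU-copied model weights and the field.
        torch.save({'model_state_dict': {k: v.cpu() for k, v in model.state_dict().items()},
                    'field': field},
                   os.path.join(args.log_dir, f'iteration_{iteration}.pth'))
        if world_size > 1:
            # Distributed training only: sync all ranks, write one optimizer
            # checkpoint per rank, then sync again before training resumes.
            torch.distributed.barrier()
            torch.save(opt.state_dict(),
                       os.path.join(args.log_dir, f'iteration_{iteration}_rank_{rank}_optim.pth'))
            torch.distributed.barrier()

In single-process runs the branch is skipped entirely, so no per-rank iteration_*_rank_*_optim.pth files are written, which is where the disk-space saving mentioned in the commit message comes from.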