train: don't save per-worker checkpoints if we're not doing distributed training
Saves disk space
parent a36f2efb8c
commit d4b35d7ae6
train.py (3 additions, 3 deletions)
@@ -215,8 +215,7 @@ def train(args, model, opt, train_iters, train_iterations, field, rank=0, world_
         torch.save({'model_state_dict': {k: v.cpu() for k, v in model.state_dict().items()}, 'field': field}, os.path.join(args.log_dir, f'iteration_{iteration}.pth'))
-        if world_size > 1:
-            torch.distributed.barrier()
-        torch.save(opt.state_dict(), os.path.join(args.log_dir, f'iteration_{iteration}_rank_{rank}_optim.pth'))
+        if world_size > 1:
+            torch.save(opt.state_dict(), os.path.join(args.log_dir, f'iteration_{iteration}_rank_{rank}_optim.pth'))
+            torch.distributed.barrier()
 
         # lr update
 
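For context, a minimal sketch of the checkpoint-saving step as it looks after this commit. The wrapper function save_checkpoint and its argument list are hypothetical; args.log_dir, iteration, rank, world_size, opt, and field follow the names used in the diff, and the surrounding training loop is assumed.

import os
import torch

def save_checkpoint(args, model, opt, field, iteration, rank=0, world_size=1):
    # Full checkpoint: model weights (moved to CPU so the file can be loaded
    # on any device) plus the field object.
    torch.save(
        {'model_state_dict': {k: v.cpu() for k, v in model.state_dict().items()},
         'field': field},
        os.path.join(args.log_dir, f'iteration_{iteration}.pth'))

    # Per-worker optimizer state is only written when training is distributed;
    # a single-process run skips it entirely, which is what saves disk space.
    if world_size > 1:
        torch.save(
            opt.state_dict(),
            os.path.join(args.log_dir, f'iteration_{iteration}_rank_{rank}_optim.pth'))
        # Assumes torch.distributed is initialized whenever world_size > 1;
        # the barrier keeps ranks in step until every optimizer file is on disk.
        torch.distributed.barrier()

This mirrors the new ordering in the diff: each rank writes its own optimizer file first, then all ranks synchronize at the barrier.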