[fix] Add barrier to accelerator's teardown (#6814)

ananthsub 2021-04-26 02:23:29 -07:00 committed by GitHub
parent 68eac4d948
commit bc3f08b0e3
3 changed files with 36 additions and 1 deletion


@@ -233,6 +233,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed
+- Added a barrier in the accelerator `teardown` to synchronize processes before execution finishes ([#6814](https://github.com/PyTorchLightning/pytorch-lightning/pull/6814))
 - Fixed multi-node DDP sub-process launch by using `local_rank` instead of `global_rank` for main process assertion ([#7061](https://github.com/PyTorchLightning/pytorch-lightning/pull/7061))

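For the second entry above: in a multi-node DDP launch each node spawns its own sub-processes, so the launching process has `local_rank == 0` on every node, while its `global_rank` is non-zero everywhere but node 0. A minimal sketch of that distinction (the helper below is illustrative, not Lightning's API):

```python
# Illustrative rank bookkeeping for a 2-node x 4-GPU job (not Lightning code).
def global_rank(node_rank: int, local_rank: int, gpus_per_node: int = 4) -> int:
    return node_rank * gpus_per_node + local_rank

# The launching process on node 1 has local_rank 0 but global_rank 4,
# so asserting `global_rank == 0` there would wrongly fail.
assert global_rank(node_rank=1, local_rank=0) == 4
assert global_rank(node_rank=0, local_rank=0) == 0
```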

@@ -140,8 +140,11 @@ class Accelerator:
         """
         This method is called to teardown the training process.
         It is the right place to release memory and free other resources.
+        By default we add a barrier here to synchronize processes before returning
+        control back to the caller.
         """
-        pass
+        self.barrier("teardown")

     def batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
         """Moves the batch to the correct device.

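For context on the `self.barrier("teardown")` call above: it is routed through the training type plugin and, for DDP, ultimately behaves like `torch.distributed.barrier()`, blocking each process until all ranks arrive. A minimal sketch of the pattern, assuming an initialized `torch.distributed` process group (the function name is hypothetical, not Lightning's API):

```python
import torch.distributed as dist

def teardown_with_barrier() -> None:
    # Block this process until every rank reaches the same point, so no rank
    # hands control back to the caller before the others have finished.
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
```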

@@ -1923,3 +1923,32 @@ def test_model_in_correct_mode_during_stages(tmpdir, accelerator, num_processes)
     trainer.validate(model)
     trainer.test(model)
     trainer.predict(model, model.val_dataloader())
+
+
+class TestDummyModelForCheckpoint(BoringModel):
+
+    def validation_step(self, batch, batch_idx):
+        output = self.layer(batch)
+        loss = self.loss(batch, output)
+        self.log('x', loss)
+
+    def validation_epoch_end(self, outputs) -> None:
+        pass
+
+
+@RunIf(skip_windows=True)
+def test_fit_test_synchronization(tmpdir):
+    """Test that the trainer synchronizes processes before returning control back to the caller."""
+    tutils.set_random_master_port()
+    model = TestDummyModelForCheckpoint()
+    checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor='x', mode='min', save_top_k=1)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        accelerator='ddp_cpu',
+        num_processes=2,
+        callbacks=[checkpoint],
+    )
+    trainer.fit(model)
+    assert os.path.exists(checkpoint.best_model_path), f'Could not find checkpoint at rank {trainer.global_rank}'
+    trainer.test()
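
The race this test guards against: rank 0 writes the best checkpoint during `fit`, and without a barrier at teardown another rank can reach `trainer.test()` before that file exists. A standalone sketch of the same pattern in plain `torch.distributed`, assuming an initialized process group (the file name and function are hypothetical):

```python
import os
import torch.distributed as dist

def write_then_read(path: str = "best.ckpt") -> None:
    # Rank 0 produces the artifact; every rank then expects to find it.
    if dist.get_rank() == 0:
        with open(path, "w") as f:
            f.write("checkpoint")
    dist.barrier()  # without this, non-zero ranks may look for the file too early
    assert os.path.exists(path), f"missing checkpoint at rank {dist.get_rank()}"
```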