diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7a0d36d952..6ffdd8085f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -233,6 +233,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Added a barrier in the accelerator `teardown` to synchronize processes before execution finishes ([#6814](https://github.com/PyTorchLightning/pytorch-lightning/pull/6814))
+
+
 - Fixed multi-node DDP sub-process launch by using `local_rank` instead of `global_rank` for main process assertion ([#7061](https://github.com/PyTorchLightning/pytorch-lightning/pull/7061))
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 66e92ae006..c1d7878e4e 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -140,8 +140,11 @@ class Accelerator:
         """
         This method is called to tear down the training process.
         It is the right place to release memory and free other resources.
+
+        By default, we add a barrier here to synchronize processes before returning
+        control back to the caller.
         """
-        pass
+        self.barrier("teardown")
 
     def batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
         """Moves the batch to the correct device.
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index e556d5fbf6..b988dd3bec 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1923,3 +1923,32 @@ def test_model_in_correct_mode_during_stages(tmpdir, accelerator, num_processes)
     trainer.validate(model)
     trainer.test(model)
     trainer.predict(model, model.val_dataloader())
+
+
+class TestDummyModelForCheckpoint(BoringModel):
+
+    def validation_step(self, batch, batch_idx):
+        output = self.layer(batch)
+        loss = self.loss(batch, output)
+        self.log('x', loss)
+
+    def validation_epoch_end(self, outputs) -> None:
+        pass
+
+
+@RunIf(skip_windows=True)
+def test_fit_test_synchronization(tmpdir):
+    """Test that the trainer synchronizes processes before returning control back to the caller."""
+    tutils.set_random_master_port()
+    model = TestDummyModelForCheckpoint()
+    checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor='x', mode='min', save_top_k=1)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        accelerator='ddp_cpu',
+        num_processes=2,
+        callbacks=[checkpoint],
+    )
+    trainer.fit(model)
+    assert os.path.exists(checkpoint.best_model_path), f'Could not find checkpoint at rank {trainer.global_rank}'
+    trainer.test()
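
Note: the race this barrier closes can be reproduced outside Lightning. The sketch below is a minimal illustration, not the PR's code: one rank writes a file (standing in for ModelCheckpoint on rank 0) and the others read it after `torch.distributed.barrier()`, mirroring what `self.barrier("teardown")` guarantees before `trainer.fit()` returns. The worker name, checkpoint path, and MASTER_ADDR/MASTER_PORT defaults are assumptions for a local two-process `gloo` run.

import os
import tempfile

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int, ckpt_path: str) -> None:
    # Env-based rendezvous for a local run; `gloo` works on CPU.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Rank 0 plays the role of the checkpoint writer.
    if rank == 0:
        torch.save({"weights": torch.zeros(1)}, ckpt_path)

    # Without this barrier, non-zero ranks may reach the assertion before
    # rank 0 has finished writing -- the race the teardown barrier closes.
    dist.barrier()

    assert os.path.exists(ckpt_path), f"rank {rank} could not see the checkpoint"
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    ckpt_path = os.path.join(tempfile.mkdtemp(), "best.ckpt")
    mp.spawn(_worker, args=(world_size, ckpt_path), nprocs=world_size)

Removing the `dist.barrier()` call makes the assertion fail intermittently on non-zero ranks, which is the same failure mode `test_fit_test_synchronization` guards against when `trainer.test()` loads the best checkpoint after `fit`.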