[fix] Add barrier to accelerator's teardown (#6814)
This commit is contained in:
parent 68eac4d948
commit bc3f08b0e3
@@ -233,6 +233,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Added a barrier in the accelerator `teardown` to synchronize processes before execution finishes ([#6814](https://github.com/PyTorchLightning/pytorch-lightning/pull/6814))
 - Fixed multi-node DDP sub-process launch by using `local_rank` instead of `global_rank` for main process assertion ([#7061](https://github.com/PyTorchLightning/pytorch-lightning/pull/7061))
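For context on the second entry above (#7061): the difference between `local_rank` and `global_rank` is easiest to see with plain rank arithmetic. The snippet below is an illustrative sketch only, not code from either PR; the variable names and the 2-node / 4-process layout are made up.

```python
# Illustration (not code from either PR) of why the "main process" check in
# multi-node DDP must use the *local* rank: the launching process on each node
# has local_rank == 0, but only on node 0 does it also have global_rank == 0.
num_nodes, procs_per_node = 2, 4

for node_rank in range(num_nodes):
    for local_rank in range(procs_per_node):
        global_rank = node_rank * procs_per_node + local_rank
        is_node_launcher = local_rank == 0   # correct per-node check
        is_rank_zero = global_rank == 0      # true only on node 0
        print(f"node={node_rank} local_rank={local_rank} "
              f"global_rank={global_rank} launcher={is_node_launcher} rank0={is_rank_zero}")
```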
@@ -140,8 +140,11 @@ class Accelerator:
         """
         This method is called to teardown the training process.
         It is the right place to release memory and free other resources.
+
+        By default we add a barrier here to synchronize processes before returning
+        control back to the caller.
         """
-        pass
+        self.barrier("teardown")
 
     def batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
         """Moves the batch to the correct device.
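The hunk above is the substance of the fix: `teardown` no longer falls through with `pass`, it now ends by calling `self.barrier("teardown")`. As a minimal sketch only (not part of this commit; `MyAccelerator` and `_release_custom_buffers` are made-up names), a custom accelerator that overrides `teardown` would keep this synchronization by delegating back to the base class:

```python
from pytorch_lightning.accelerators import Accelerator


class MyAccelerator(Accelerator):
    """Hypothetical accelerator subclass used only for illustration."""

    def _release_custom_buffers(self) -> None:
        # Made-up helper: free whatever extra state the custom accelerator holds.
        pass

    def teardown(self) -> None:
        self._release_custom_buffers()
        # Delegate to the base class, which now ends with `self.barrier("teardown")`,
        # so all ranks still synchronize before control returns to the caller.
        super().teardown()
```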
@@ -1923,3 +1923,32 @@ def test_model_in_correct_mode_during_stages(tmpdir, accelerator, num_processes)
     trainer.validate(model)
     trainer.test(model)
     trainer.predict(model, model.val_dataloader())
+
+
+class TestDummyModelForCheckpoint(BoringModel):
+
+    def validation_step(self, batch, batch_idx):
+        output = self.layer(batch)
+        loss = self.loss(batch, output)
+        self.log('x', loss)
+
+    def validation_epoch_end(self, outputs) -> None:
+        pass
+
+
+@RunIf(skip_windows=True)
+def test_fit_test_synchronization(tmpdir):
+    """Test that the trainer synchronizes processes before returning control back to the caller."""
+    tutils.set_random_master_port()
+    model = TestDummyModelForCheckpoint()
+    checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor='x', mode='min', save_top_k=1)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        accelerator='ddp_cpu',
+        num_processes=2,
+        callbacks=[checkpoint],
+    )
+    trainer.fit(model)
+    assert os.path.exists(checkpoint.best_model_path), f'Could not find checkpoint at rank {trainer.global_rank}'
+    trainer.test()
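As a standalone illustration of the synchronization pattern the new test exercises, here is a small sketch using raw `torch.distributed` (not part of this commit; the file name, port, and gloo/2-process setup are arbitrary assumptions): rank 0 writes a file, every rank blocks on a barrier, and only then does each rank assert that the file exists.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int, ckpt_dir: str) -> None:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    ckpt_path = os.path.join(ckpt_dir, "best.ckpt")
    if rank == 0:
        torch.save({"state_dict": {}}, ckpt_path)  # only rank 0 writes the file

    dist.barrier()  # analogous to self.barrier("teardown") in the accelerator

    # After the barrier, every rank can safely assume the checkpoint exists.
    assert os.path.exists(ckpt_path), f"Could not find checkpoint at rank {rank}"
    dist.destroy_process_group()


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmpdir:
        mp.spawn(worker, args=(2, tmpdir), nprocs=2)
```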