Fix ModelCheckpoint race condition in file existence check (#5155)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
This commit is contained in:
Adrian Wälchli 2021-01-27 16:27:43 +01:00 committed by Jirka Borovec
parent 605c5a8c9a
commit bb7d188318
6 changed files with 31 additions and 9 deletions

View File

@ -193,6 +193,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed support custom DataLoader with DDP if they can be re-instantiated ([#5745](https://github.com/PyTorchLightning/pytorch-lightning/pull/5745))
- Fixed a race condition in `ModelCheckpoint` when checking if a checkpoint file exists ([#5144](https://github.com/PyTorchLightning/pytorch-lightning/pull/5144))
## [1.1.6] - 2021-01-26
### Changed

View File

@ -147,6 +147,7 @@ class HorovodAccelerator(Accelerator):
hvd.join()
def broadcast(self, obj, src=0):
    """Synchronize all Horovod workers, then broadcast ``obj`` from rank ``src`` to every rank."""
    # Barrier first so no worker starts the broadcast before all have arrived.
    self.barrier()
    return hvd.broadcast_object(obj, src)

View File

@ -331,6 +331,9 @@ class TPUAccelerator(Accelerator):
mp_queue.put(last_path)
def broadcast(self, obj, src=0):
if self.trainer.tpu_id is not None:
# running on a single core
return obj
buffer = io.BytesIO()
torch.save(obj, buffer)
data = bytearray(buffer.getbuffer())

View File

@ -501,13 +501,16 @@ class ModelCheckpoint(Callback):
monitor_candidates: Dict[str, Any],
epoch: int,
step: int,
del_filepath: Optional[str] = None
trainer,
del_filepath: Optional[str] = None,
) -> str:
filepath = self.format_checkpoint_name(epoch, step, monitor_candidates)
version = self.STARTING_VERSION
while self._fs.exists(filepath) and filepath != del_filepath:
filepath = self.format_checkpoint_name(epoch, step, monitor_candidates, ver=version)
version += 1
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
version_cnt = 0
while self.file_exists(filepath, trainer) and filepath != del_filepath:
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
version_cnt += 1
return filepath
def _monitor_candidates(self, trainer):
@ -532,7 +535,7 @@ class ModelCheckpoint(Callback):
last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}")
else:
last_filepath = self._get_metric_interpolated_filepath_name(
monitor_candidates, trainer.current_epoch, trainer.global_step
ckpt_name_metrics, trainer.current_epoch, trainer.global_step, trainer,
)
accelerator_backend = trainer.accelerator_backend
@ -589,7 +592,7 @@ class ModelCheckpoint(Callback):
if isinstance(current, torch.Tensor) and torch.isnan(current):
current = torch.tensor(float('inf' if self.mode == "min" else '-inf'))
filepath = self._get_metric_interpolated_filepath_name(ckpt_name_metrics, epoch, step, del_filepath)
filepath = self._get_metric_interpolated_filepath_name(ckpt_name_metrics, epoch, step, trainer, del_filepath)
# save the current score
self.current_score = current
@ -627,3 +630,13 @@ class ModelCheckpoint(Callback):
filepath = os.path.join(self.dirpath, "best_k_models.yaml")
with self._fs.open(filepath, "w") as fp:
yaml.dump(best_k, fp)
def file_exists(self, filepath: Union[str, Path], trainer) -> bool:
    """
    Check whether ``filepath`` exists and broadcast the result to all ranks,
    preventing the internal state from diverging between ranks.
    """
    # Each rank evaluates locally; the accelerator broadcast (when available)
    # makes rank 0's answer authoritative for everyone.
    backend = trainer.accelerator_backend
    exists = self._fs.exists(filepath)
    if backend is None:
        return exists
    return backend.broadcast(exists)

View File

@ -170,7 +170,8 @@ class ModelCheckpointTestInvocations(ModelCheckpoint):
assert self.best_model_score
assert self.on_save_checkpoint_count == self.expected_count
if trainer.is_global_zero:
assert torch.save.call_count == self.expected_count
# twice the calls expected because ddp broadcast also uses torch.save
assert torch.save.call_count == self.expected_count * 2
else:
assert torch.save.call_count == 0

View File

@ -121,6 +121,7 @@ def test_horovod_multi_gpu(tmpdir):
_run_horovod(trainer_options, on_gpu=True)
@pytest.mark.skip(reason="Horovod has a problem with broadcast when using apex?")
@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
@pytest.mark.skipif(not _HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")