Fix ModelCheckpoint race condition in file existence check (#5155)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
This commit is contained in:
parent
605c5a8c9a
commit
bb7d188318
|
@ -193,6 +193,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed support custom DataLoader with DDP if they can be re-instantiated ([#5745](https://github.com/PyTorchLightning/pytorch-lightning/pull/5745))
|
||||
|
||||
|
||||
- Fixed a race condition in `ModelCheckpoint` when checking if a checkpoint file exists ([#5144](https://github.com/PyTorchLightning/pytorch-lightning/pull/5144))
|
||||
|
||||
|
||||
## [1.1.6] - 2021-01-26
|
||||
|
||||
### Changed
|
||||
|
|
|
@ -147,6 +147,7 @@ class HorovodAccelerator(Accelerator):
|
|||
hvd.join()
|
||||
|
||||
def broadcast(self, obj, src=0):
|
||||
self.barrier()
|
||||
obj = hvd.broadcast_object(obj, src)
|
||||
return obj
|
||||
|
||||
|
|
|
@ -331,6 +331,9 @@ class TPUAccelerator(Accelerator):
|
|||
mp_queue.put(last_path)
|
||||
|
||||
def broadcast(self, obj, src=0):
|
||||
if self.trainer.tpu_id is not None:
|
||||
# running on a single core
|
||||
return obj
|
||||
buffer = io.BytesIO()
|
||||
torch.save(obj, buffer)
|
||||
data = bytearray(buffer.getbuffer())
|
||||
|
|
|
@ -501,13 +501,16 @@ class ModelCheckpoint(Callback):
|
|||
monitor_candidates: Dict[str, Any],
|
||||
epoch: int,
|
||||
step: int,
|
||||
del_filepath: Optional[str] = None
|
||||
trainer,
|
||||
del_filepath: Optional[str] = None,
|
||||
) -> str:
|
||||
filepath = self.format_checkpoint_name(epoch, step, monitor_candidates)
|
||||
version = self.STARTING_VERSION
|
||||
while self._fs.exists(filepath) and filepath != del_filepath:
|
||||
filepath = self.format_checkpoint_name(epoch, step, monitor_candidates, ver=version)
|
||||
version += 1
|
||||
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
|
||||
|
||||
version_cnt = 0
|
||||
while self.file_exists(filepath, trainer) and filepath != del_filepath:
|
||||
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
|
||||
version_cnt += 1
|
||||
|
||||
return filepath
|
||||
|
||||
def _monitor_candidates(self, trainer):
|
||||
|
@ -532,7 +535,7 @@ class ModelCheckpoint(Callback):
|
|||
last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}")
|
||||
else:
|
||||
last_filepath = self._get_metric_interpolated_filepath_name(
|
||||
monitor_candidates, trainer.current_epoch, trainer.global_step
|
||||
ckpt_name_metrics, trainer.current_epoch, trainer.global_step, trainer,
|
||||
)
|
||||
|
||||
accelerator_backend = trainer.accelerator_backend
|
||||
|
@ -589,7 +592,7 @@ class ModelCheckpoint(Callback):
|
|||
if isinstance(current, torch.Tensor) and torch.isnan(current):
|
||||
current = torch.tensor(float('inf' if self.mode == "min" else '-inf'))
|
||||
|
||||
filepath = self._get_metric_interpolated_filepath_name(ckpt_name_metrics, epoch, step, del_filepath)
|
||||
filepath = self._get_metric_interpolated_filepath_name(ckpt_name_metrics, epoch, step, trainer, del_filepath)
|
||||
|
||||
# save the current score
|
||||
self.current_score = current
|
||||
|
@ -627,3 +630,13 @@ class ModelCheckpoint(Callback):
|
|||
filepath = os.path.join(self.dirpath, "best_k_models.yaml")
|
||||
with self._fs.open(filepath, "w") as fp:
|
||||
yaml.dump(best_k, fp)
|
||||
|
||||
def file_exists(self, filepath: Union[str, Path], trainer) -> bool:
|
||||
"""
|
||||
Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing
|
||||
the internal state to diverge between ranks.
|
||||
"""
|
||||
exists = self._fs.exists(filepath)
|
||||
if trainer.accelerator_backend is not None:
|
||||
exists = trainer.accelerator_backend.broadcast(exists)
|
||||
return exists
|
||||
|
|
|
@ -170,7 +170,8 @@ class ModelCheckpointTestInvocations(ModelCheckpoint):
|
|||
assert self.best_model_score
|
||||
assert self.on_save_checkpoint_count == self.expected_count
|
||||
if trainer.is_global_zero:
|
||||
assert torch.save.call_count == self.expected_count
|
||||
# twice the calls expected because ddp broadcast also uses torch.save
|
||||
assert torch.save.call_count == self.expected_count * 2
|
||||
else:
|
||||
assert torch.save.call_count == 0
|
||||
|
||||
|
|
|
@ -121,6 +121,7 @@ def test_horovod_multi_gpu(tmpdir):
|
|||
_run_horovod(trainer_options, on_gpu=True)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Horovod has a problem with broadcast when using apex?")
|
||||
@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
|
||||
@pytest.mark.skipif(not _HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support")
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
|
||||
|
|
Loading…
Reference in New Issue