[TPU] Save checkpoint on global zero only and sync before (#17882)

Author: Carlos Mocholí, 2023-06-21 17:49:31 +02:00 (committed by GitHub)
parent 6b0ec10ab0
commit 287bdebaa6
4 changed files with 21 additions and 3 deletions


@@ -74,6 +74,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for loading optimizer states from a full-state checkpoint file ([#17747](https://github.com/Lightning-AI/lightning/pull/17747))
- Automatically call `xla_model.mark_step()` before saving checkpoints with XLA ([#17882](https://github.com/Lightning-AI/lightning/pull/17882))
- Automatically call `xla_model.mark_step()` after `optimizer.step()` with XLA ([#17883](https://github.com/Lightning-AI/lightning/pull/17883))
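Both XLA-related entries above describe behavior that the strategies now apply automatically. For readers unfamiliar with lazy XLA tensors, here is a minimal hand-written sketch of the same pattern, assuming a `torch_xla` installation; the toy model, optimizer, and `toy.ckpt` path are illustrative and not part of Lightning's API:

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(4, 2).to(device)                 # toy model (illustrative)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(3):  # a tiny training loop
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4, device=device)).sum()
    loss.backward()
    optimizer.step()
    xm.mark_step()  # what #17883 automates: dispatch the recorded graph after each step

# what #17882 automates: flush any pending work on every rank before saving ...
xm.mark_step()
# ... then write the checkpoint from global rank zero only
if xm.get_ordinal() == 0:
    torch.save({k: v.cpu() for k, v in model.state_dict().items()}, "toy.ckpt")
```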


@@ -230,9 +230,12 @@ class XLAStrategy(ParallelStrategy):
             filter: An optional dictionary of the same format as ``state`` mapping keys to callables that return a
                 boolean indicating whether the given parameter should be saved (``True``) or filtered out (``False``).
         """
-        state = self._convert_stateful_objects_in_state(state, filter=filter or {})
-        # `xla_model.save` needs to be called on all ranks. It internally checks if the local rank is 0
-        self.checkpoint_io.save_checkpoint(state, path, storage_options=storage_options)
+        import torch_xla.core.xla_model as xm
+
+        # sync any pending lazy tensors on all ranks before saving to prevent potential collective hangs
+        xm.mark_step()
+        # save on global rank zero only
+        super().save_checkpoint(path, state, storage_options=storage_options, filter=filter)
 
     def remove_checkpoint(self, filepath: _PATH) -> None:
         """Remove checkpoint filepath from the filesystem.

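For context on why `mark_step()` must run on every rank before a rank-zero-only save: with lazy XLA tensors, writing a file forces execution of the pending graph, and if that graph contains a collective (an all-reduce, for example) that only rank zero executes, the collective waits indefinitely for peers that never join. A minimal sketch of the hazard and the fix, assuming `torch_xla` with processes launched via `xmp.spawn`; the worker function and file name are illustrative, not part of this commit:

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _worker(index: int) -> None:
    device = xm.xla_device()
    t = torch.ones(2, 2, device=device)
    reduced = xm.all_reduce(xm.REDUCE_SUM, t)  # recorded lazily, not executed yet

    # Hazard: calling `reduced.cpu()` (or `torch.save`) on rank 0 alone would force the
    # all-reduce on that rank only, and the collective then hangs waiting for the others.
    # Fix, mirroring the change above: cut the graph on *every* rank first ...
    xm.mark_step()
    # ... then let only global rank zero touch the filesystem.
    if xm.get_ordinal() == 0:
        torch.save(reduced.cpu(), "reduced.pt")


if __name__ == "__main__":
    xmp.spawn(_worker)  # one process per XLA device
```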

@@ -57,6 +57,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Updated `LearningRateMonitor` to log monitored values to `trainer.callback_metrics` ([#17626](https://github.com/Lightning-AI/lightning/pull/17626))
- Automatically call `xla_model.mark_step()` before saving checkpoints with XLA ([#17882](https://github.com/Lightning-AI/lightning/pull/17882))

### Changed

- Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))


@@ -234,6 +234,16 @@ class XLAStrategy(DDPStrategy):
     def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
         self._pod_progress_bar_force_stdout()
 
+    def save_checkpoint(
+        self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
+    ) -> None:
+        import torch_xla.core.xla_model as xm
+
+        # sync any pending lazy tensors on all ranks before saving to prevent potential collective hangs
+        xm.mark_step()
+        # save on global rank zero only
+        super().save_checkpoint(checkpoint, filepath, storage_options=storage_options)
+
     def remove_checkpoint(self, filepath: _PATH) -> None:
         """Remove checkpoint filepath from the filesystem.
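At the Trainer level, the net effect is that checkpointing on TPU needs no special handling from the user. A usage-level sketch, assuming Lightning ≥ 2.0 with TPU support installed; `DemoModel`, the random dataset, and `demo.ckpt` are placeholders rather than anything defined in this commit:

```python
import torch
import lightning.pytorch as pl
from torch.utils.data import DataLoader, TensorDataset


class DemoModel(pl.LightningModule):  # placeholder module for illustration
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


train_data = DataLoader(TensorDataset(torch.randn(64, 32), torch.randn(64, 2)), batch_size=8)

# With `accelerator="tpu"` the XLA strategy is selected automatically. After this change,
# every checkpoint write (from `ModelCheckpoint` or the explicit call below) first runs
# `xm.mark_step()` on all ranks and then writes the file from global rank zero only.
trainer = pl.Trainer(accelerator="tpu", devices=8, max_epochs=1)
trainer.fit(DemoModel(), train_data)
trainer.save_checkpoint("demo.ckpt")
```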