[TPU] Save checkpoint on global zero only and sync before (#17882)
parent 6b0ec10ab0
commit 287bdebaa6
@@ -74,6 +74,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for loading optimizer states from a full-state checkpoint file ([#17747](https://github.com/Lightning-AI/lightning/pull/17747))
 
+- Automatically call `xla_model.mark_step()` before saving checkpoints with XLA ([#17882](https://github.com/Lightning-AI/lightning/pull/17882))
+
 - Automatically call `xla_model.mark_step()` after `optimizer.step()` with XLA ([#17883](https://github.com/Lightning-AI/lightning/pull/17883))
 
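For context on the changelog entry above: XLA tensors are evaluated lazily, so operations accumulate in a pending graph until a step marker forces execution. The sketch below is not code from this commit but a hypothetical standalone illustration of why `xm.mark_step()` is called before saving; it assumes a host with `torch_xla` installed and an XLA device available, and the model and file name are placeholders:

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    model = torch.nn.Linear(4, 2).to(device)

    # Work on XLA tensors is recorded lazily; the forward pass below has not executed yet.
    loss = model(torch.randn(8, 4, device=device)).sum()

    # Flush the pending lazy graph on every rank before materializing tensor values,
    # so no rank is left waiting inside a collective while another one writes a file.
    xm.mark_step()

    # Safe to pull values to the host and save them now (file name is a placeholder).
    torch.save({k: v.cpu() for k, v in model.state_dict().items()}, "model.pt")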
@@ -230,9 +230,12 @@ class XLAStrategy(ParallelStrategy):
             filter: An optional dictionary of the same format as ``state`` mapping keys to callables that return a
                 boolean indicating whether the given parameter should be saved (``True``) or filtered out (``False``).
         """
-        state = self._convert_stateful_objects_in_state(state, filter=filter or {})
-        # `xla_model.save` needs to be called on all ranks. It internally checks if the local rank is 0
-        self.checkpoint_io.save_checkpoint(state, path, storage_options=storage_options)
+        import torch_xla.core.xla_model as xm
+
+        # sync any pending lazy tensors on all ranks before saving to prevent potential collective hangs
+        xm.mark_step()
+        # save on global rank zero only
+        super().save_checkpoint(path, state, storage_options=storage_options, filter=filter)
 
     def remove_checkpoint(self, filepath: _PATH) -> None:
         """Remove checkpoint filepath from the filesystem.
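A minimal sketch of how the updated Fabric-side `XLAStrategy.save_checkpoint` would be reached from user code; the training function, tensor shapes, and checkpoint path are made up for illustration, and a TPU environment with `torch_xla` installed is assumed:

    import torch
    from lightning.fabric import Fabric


    def train(fabric: Fabric) -> None:
        model = torch.nn.Linear(32, 2)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        model, optimizer = fabric.setup(model, optimizer)

        # One step of work leaves lazily recorded XLA operations behind on every rank.
        loss = model(torch.randn(16, 32, device=fabric.device)).sum()
        fabric.backward(loss)
        optimizer.step()

        # fabric.save() routes into XLAStrategy.save_checkpoint: every rank runs
        # xm.mark_step(), then the checkpoint is written on global rank zero only.
        fabric.save("checkpoint.ckpt", {"model": model, "optimizer": optimizer})


    if __name__ == "__main__":
        fabric = Fabric(accelerator="tpu", devices=8)
        fabric.launch(train)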
@@ -57,6 +57,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Updated `LearningRateMonitor` to log monitored values to `trainer.callback_metrics` ([#17626](https://github.com/Lightning-AI/lightning/pull/17626))
 
+- Automatically call `xla_model.mark_step()` before saving checkpoints with XLA ([#17882](https://github.com/Lightning-AI/lightning/pull/17882))
+
 ### Changed
 
 - Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))
 
@@ -234,6 +234,16 @@ class XLAStrategy(DDPStrategy):
     def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
         self._pod_progress_bar_force_stdout()
 
+    def save_checkpoint(
+        self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
+    ) -> None:
+        import torch_xla.core.xla_model as xm
+
+        # sync any pending lazy tensors on all ranks before saving to prevent potential collective hangs
+        xm.mark_step()
+        # save on global rank zero only
+        super().save_checkpoint(checkpoint, filepath, storage_options=storage_options)
+
     def remove_checkpoint(self, filepath: _PATH) -> None:
         """Remove checkpoint filepath from the filesystem.
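On the Trainer side, the new override is reached through the normal checkpointing path. Below is a rough, hypothetical sketch under the same TPU/`torch_xla` assumption; the `TinyModule`, dataset, and `dirpath` are placeholders:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import lightning.pytorch as pl
    from lightning.pytorch.callbacks import ModelCheckpoint


    class TinyModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            (x,) = batch
            return self.layer(x).sum()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    if __name__ == "__main__":
        train_data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
        # Checkpoint saving goes through the Trainer-side XLAStrategy.save_checkpoint shown
        # above: xm.mark_step() runs on all ranks, then the file is written on global rank
        # zero only.
        trainer = pl.Trainer(
            accelerator="tpu",
            devices=8,
            max_epochs=1,
            callbacks=[ModelCheckpoint(dirpath="checkpoints/")],
        )
        trainer.fit(TinyModule(), train_data)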