diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 88f9357244..27ec226acc 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -74,6 +74,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for loading optimizer states from a full-state checkpoint file ([#17747](https://github.com/Lightning-AI/lightning/pull/17747))
 
+- Automatically call `xla_model.mark_step()` before saving checkpoints with XLA ([#17882](https://github.com/Lightning-AI/lightning/pull/17882))
+
+
 - Automatically call `xla_model.mark_step()` after `optimizer.step()` with XLA ([#17883](https://github.com/Lightning-AI/lightning/pull/17883))
 
diff --git a/src/lightning/fabric/strategies/xla.py b/src/lightning/fabric/strategies/xla.py
index 91976f30d4..a2b27a78c9 100644
--- a/src/lightning/fabric/strategies/xla.py
+++ b/src/lightning/fabric/strategies/xla.py
@@ -230,9 +230,12 @@ class XLAStrategy(ParallelStrategy):
             filter: An optional dictionary of the same format as ``state`` mapping keys to callables that return a
                 boolean indicating whether the given parameter should be saved (``True``) or filtered out (``False``).
         """
-        state = self._convert_stateful_objects_in_state(state, filter=filter or {})
-        # `xla_model.save` needs to be called on all ranks. It internally checks if the local rank is 0
-        self.checkpoint_io.save_checkpoint(state, path, storage_options=storage_options)
+        import torch_xla.core.xla_model as xm
+
+        # sync any pending lazy tensors on all ranks before saving to prevent potential collective hangs
+        xm.mark_step()
+        # save on global rank zero only
+        super().save_checkpoint(path, state, storage_options=storage_options, filter=filter)
 
     def remove_checkpoint(self, filepath: _PATH) -> None:
         """Remove checkpoint filepath from the filesystem.
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index aa91b7076d..d0670125d4 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -57,6 +57,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Updated `LearningRateMonitor` to log monitored values to `trainer.callback_metrics` ([#17626](https://github.com/Lightning-AI/lightning/pull/17626))
 
+- Automatically call `xla_model.mark_step()` before saving checkpoints with XLA ([#17882](https://github.com/Lightning-AI/lightning/pull/17882))
+
 ### Changed
 
 - Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))
diff --git a/src/lightning/pytorch/strategies/xla.py b/src/lightning/pytorch/strategies/xla.py
index 5ea1d4bf92..4d3788b20b 100644
--- a/src/lightning/pytorch/strategies/xla.py
+++ b/src/lightning/pytorch/strategies/xla.py
@@ -234,6 +234,16 @@ class XLAStrategy(DDPStrategy):
     def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
         self._pod_progress_bar_force_stdout()
 
+    def save_checkpoint(
+        self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
+    ) -> None:
+        import torch_xla.core.xla_model as xm
+
+        # sync any pending lazy tensors on all ranks before saving to prevent potential collective hangs
+        xm.mark_step()
+        # save on global rank zero only
+        super().save_checkpoint(checkpoint, filepath, storage_options=storage_options)
+
     def remove_checkpoint(self, filepath: _PATH) -> None:
         """Remove checkpoint filepath from the filesystem.
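
For context, not part of the diff above: a minimal sketch of how the changed code path is exercised when saving a checkpoint with Fabric on an XLA/TPU accelerator. The `train` function, model dimensions, and the `checkpoint.ckpt` path are made up for illustration; only the public Fabric API calls (`Fabric`, `launch`, `setup`, `backward`, `save`) are real, and `fabric.save()` is what ends up in `XLAStrategy.save_checkpoint()` shown in the diff.

```python
import lightning as L
import torch


def train(fabric: L.Fabric) -> None:
    # Illustrative model and single training step; shapes are arbitrary.
    model = torch.nn.Linear(32, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    model, optimizer = fabric.setup(model, optimizer)

    x = torch.randn(4, 32, device=fabric.device)
    loss = model(x).sum()
    fabric.backward(loss)
    optimizer.step()

    # `fabric.save` routes to `XLAStrategy.save_checkpoint`, which (per this diff)
    # first calls `xm.mark_step()` on every rank to flush pending lazy XLA ops,
    # then delegates to the parent strategy so the file is written on global rank
    # zero only.
    fabric.save("checkpoint.ckpt", {"model": model, "optimizer": optimizer})


if __name__ == "__main__":
    fabric = L.Fabric(accelerator="tpu", devices=8, strategy="xla")
    fabric.launch(train)
```

Without the `mark_step()` call, ranks that still hold unmaterialized lazy tensors can disagree during the collectives issued while gathering state, which is the hang the comment in the diff refers to.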