From 8bce39431ec6cdbe7c3db336082f7c4036170f44 Mon Sep 17 00:00:00 2001
From: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Date: Wed, 23 Jun 2021 03:09:23 +0530
Subject: [PATCH] Use XLA utility API to move data to CPU (Single TPU core)
 (#8078)

---
 .../plugins/training_type/single_tpu.py      | 29 ++++++++++++-------
 1 file changed, 19 insertions(+), 10 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py
index 99abff992e..afc692951c 100644
--- a/pytorch_lightning/plugins/training_type/single_tpu.py
+++ b/pytorch_lightning/plugins/training_type/single_tpu.py
@@ -12,17 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-
-import torch
+from typing import Any, Dict
 
 from pytorch_lightning.core.decorators import parameter_validation
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
-from pytorch_lightning.utilities import _TPU_AVAILABLE
-from pytorch_lightning.utilities.apply_func import move_data_to_device
+from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 
 if _TPU_AVAILABLE:
     import torch_xla.core.xla_model as xm
 
+if _OMEGACONF_AVAILABLE:
+    from omegaconf import DictConfig, ListConfig, OmegaConf
+
 
 class SingleTPUPlugin(SingleDevicePlugin):
     """ Plugin for training on a single TPU device. """
@@ -54,13 +56,20 @@ class SingleTPUPlugin(SingleDevicePlugin):
         self.tpu_local_core_rank = xm.get_local_ordinal()
         self.tpu_global_core_rank = xm.get_ordinal()
 
-    def on_save(self, checkpoint: dict) -> dict:
+    def save(self, state_dict: Dict, path: str) -> None:
+        xm.save(state_dict, path)
+
+    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
+        """Save model/training states as a checkpoint file through state-dump and file-write.
+
+        Args:
+            checkpoint: dict containing model and trainer state
+            filepath: write-target file's path
         """
-        Move XLA tensors to CPU before saving
-        Recommended on XLA Guide:
-        https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors
-        """
-        return move_data_to_device(checkpoint, torch.device("cpu"))
+        # Related Issue: https://github.com/pytorch/xla/issues/2773
+        if _OMEGACONF_AVAILABLE:
+            checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
+        self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)
 
     def teardown(self) -> None:
         # TPU teardown
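
Note (not part of the patch): a minimal sketch of what the new save path does, assuming torch_xla and omegaconf are installed and a TPU device is available; the checkpoint contents and file name below are made-up placeholders. xm.save moves XLA tensors to CPU before writing, which replaces the plugin's previous manual move_data_to_device(checkpoint, torch.device("cpu")) step, and OmegaConf containers are converted to plain dicts/lists first to sidestep pytorch/xla#2773.

    # Illustrative only, not part of the patch. Assumes a TPU runtime with
    # torch_xla, omegaconf, and pytorch_lightning installed.
    import torch
    import torch_xla.core.xla_model as xm
    from omegaconf import DictConfig, ListConfig, OmegaConf
    from pytorch_lightning.utilities.apply_func import apply_to_collection

    device = xm.xla_device()

    # Made-up checkpoint: XLA tensors plus an OmegaConf config.
    checkpoint = {
        "state_dict": {"weight": torch.ones(2, 2, device=device)},
        "hyper_parameters": OmegaConf.create({"lr": 1e-3}),
        "callbacks": {"ModelCheckpoint": {}},
    }

    # Convert OmegaConf containers to primitive dicts/lists (pytorch/xla#2773).
    checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)

    # xm.save transfers XLA tensors to CPU before torch.save; the "callbacks"
    # entry is dropped, mirroring the plugin's save_checkpoint.
    xm.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, "single_tpu.ckpt")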