From 7b6d0a842c78a75fafb4079fc5b84b87ed5e335f Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Mon, 5 Jul 2021 15:08:38 +0530 Subject: [PATCH] Fix progress bar updates for Pod Training (#8258) * Fix progress bar updates for Pod Training * Fix progress bar updates for Pod Training * Add _pod_progress_bar_force_stdout --- .../plugins/training_type/tpu_spawn.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 2a30ddce23..f9bc1309f1 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -34,6 +34,7 @@ from pytorch_lightning.utilities.data import has_len from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, tpu_distributed from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.seed import reset_seed +from pytorch_lightning.utilities.types import STEP_OUTPUT if _TPU_AVAILABLE: import torch_xla.core.xla_env_vars as xenv @@ -282,6 +283,26 @@ class TPUSpawnPlugin(DDPSpawnPlugin): def predict_step(self, *args, **kwargs): return self.model(*args, **kwargs) + def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT: + self._pod_progress_bar_force_stdout() + return output + + def validation_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT: + self._pod_progress_bar_force_stdout() + return output + + def test_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT: + self._pod_progress_bar_force_stdout() + return output + + def _pod_progress_bar_force_stdout(self) -> None: + # Why is it required? The way `pytorch_xla.distributed` streams logs + # from different vms to the master worker doesn't work well with tqdm + # Ref: https://github.com/pytorch/xla/blob/master/torch_xla/distributed/xla_dist.py#L140 + # The print statement seems to force tqdm to flush stdout. + if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1: + print() + def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: """Save model/training states as a checkpoint file through state-dump and file-write.