Fix progress bar updates for Pod Training (#8258)
* Fix progress bar updates for Pod Training
* Add _pod_progress_bar_force_stdout
This commit is contained in:
parent ea5cfd2005
commit 7b6d0a842c
@@ -34,6 +34,7 @@ from pytorch_lightning.utilities.data import has_len
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, tpu_distributed
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import STEP_OUTPUT

 if _TPU_AVAILABLE:
+    import torch_xla.core.xla_env_vars as xenv

@@ -282,6 +283,26 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
     def predict_step(self, *args, **kwargs):
         return self.model(*args, **kwargs)

+    def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
+        self._pod_progress_bar_force_stdout()
+        return output
+
+    def validation_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
+        self._pod_progress_bar_force_stdout()
+        return output
+
+    def test_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
+        self._pod_progress_bar_force_stdout()
+        return output
+
+    def _pod_progress_bar_force_stdout(self) -> None:
+        # Why is it required? The way `pytorch_xla.distributed` streams logs
+        # from different vms to the master worker doesn't work well with tqdm
+        # Ref: https://github.com/pytorch/xla/blob/master/torch_xla/distributed/xla_dist.py#L140
+        # The print statement seems to force tqdm to flush stdout.
+        if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
+            print()
+
     def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
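For context, below is a minimal standalone sketch (not part of the commit) of the pattern the new *_step_end hooks rely on: on the rank-zero process, a bare print() after each step writes a newline to stdout, which, per the comment in _pod_progress_bar_force_stdout, seems to force tqdm to flush so the progress bar renders correctly when pod workers stream their logs to the master. The RANK variable and the _force_stdout_flush helper are illustrative stand-ins for tpu_global_core_rank and the xenv.TPUVM_MODE check in the real plugin; only tqdm is assumed to be installed.

# Minimal sketch of the flush trick, without any TPU/torch_xla dependency.
import os
import time

from tqdm import tqdm


def _force_stdout_flush() -> None:
    # Stand-in for TPUSpawnPlugin._pod_progress_bar_force_stdout:
    # only the "rank zero" process emits the extra newline.
    # `RANK` is a hypothetical placeholder for the plugin's
    # `tpu_global_core_rank` / `xenv.TPUVM_MODE` guard.
    if int(os.getenv("RANK", "0")) == 0:
        print()


if __name__ == "__main__":
    for step in tqdm(range(5), desc="train"):
        time.sleep(0.1)
        # Called where the plugin calls it: at the end of each step.
        _force_stdout_flush()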