Fix deadlocks in distributed training with RichProgressBar (#10428)

This commit is contained in:
Kaushik B 2021-11-09 18:30:37 +05:30 committed by GitHub
parent 21eafafcb0
commit 5eeca87e98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 37 additions and 27 deletions

View File

@ -107,6 +107,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed issue with pickling `CSVLogger` after a call to `CSVLogger.save` ([#10388](https://github.com/PyTorchLightning/pytorch-lightning/pull/10388))
- Fixed deadlocks for distributed training with `RichProgressBar` ([#10428](https://github.com/PyTorchLightning/pytorch-lightning/pull/10428))
- Fixed the logging with `on_step=True` in epoch-level hooks causing unintended side-effects. Logging with `on_step=True` in epoch-level hooks will now correctly raise an error ([#10409](https://github.com/PyTorchLightning/pytorch-lightning/pull/10409))

View File

@ -129,13 +129,19 @@ if _RICH_AVAILABLE:
class MetricsTextColumn(ProgressColumn):
"""A column containing text."""
def __init__(self, trainer, pl_module):
def __init__(self, trainer):
self._trainer = trainer
self._pl_module = pl_module
self._tasks = {}
self._current_task_id = 0
self._metrics = {}
super().__init__()
def update(self, metrics):
# Store the latest metrics for `render` to display. Called when metrics are
# ready to be rendered.
# `render` only reads this cached mapping (it iterates `self._metrics.items()`)
# instead of requesting metrics itself; requesting them from separate threads
# is what caused the deadlock issues.
self._metrics = metrics
def render(self, task) -> Text:
from pytorch_lightning.trainer.states import TrainerFn
@ -149,14 +155,8 @@ if _RICH_AVAILABLE:
if self._trainer.training and task.id != self._current_task_id:
return self._tasks[task.id]
_text = ""
# TODO(@daniellepintz): make this code cleaner
progress_bar_callback = getattr(self._trainer, "progress_bar_callback", None)
if progress_bar_callback:
metrics = self._trainer.progress_bar_callback.get_metrics(self._trainer, self._pl_module)
else:
metrics = self._trainer.progress_bar_metrics
for k, v in metrics.items():
for k, v in self._metrics.items():
_text += f"{k}: {round(v, 3) if isinstance(v, float) else v} "
return Text(_text, justify="left")
@ -225,9 +225,9 @@ class RichProgressBar(ProgressBarBase):
self.progress: Optional[Progress] = None
self.val_sanity_progress_bar_id: Optional[int] = None
self._reset_progress_bar_ids()
self._metric_component = None
self._progress_stopped: bool = False
self.theme = theme
self._console: Console = Console()
@property
def refresh_rate_per_second(self) -> float:
@ -268,12 +268,15 @@ class RichProgressBar(ProgressBarBase):
def predict_description(self) -> str:
return "Predicting"
def _init_progress(self, trainer, pl_module):
if self.progress is None or self._progress_stopped:
def _init_progress(self, trainer):
if self.is_enabled and (self.progress is None or self._progress_stopped):
self._reset_progress_bar_ids()
self._console: Console = Console()
self._console.clear_live()
self._metric_component = MetricsTextColumn(trainer)
self.progress = CustomProgress(
*self.configure_columns(trainer, pl_module),
*self.configure_columns(trainer),
self._metric_component,
refresh_per_second=self.refresh_rate_per_second,
disable=self.is_disabled,
console=self._console,
@ -284,19 +287,19 @@ class RichProgressBar(ProgressBarBase):
def on_train_start(self, trainer, pl_module):
super().on_train_start(trainer, pl_module)
self._init_progress(trainer, pl_module)
self._init_progress(trainer)
def on_predict_start(self, trainer, pl_module):
super().on_predict_start(trainer, pl_module)
self._init_progress(trainer, pl_module)
self._init_progress(trainer)
def on_test_start(self, trainer, pl_module):
super().on_test_start(trainer, pl_module)
self._init_progress(trainer, pl_module)
self._init_progress(trainer)
def on_validation_start(self, trainer, pl_module):
super().on_validation_start(trainer, pl_module)
self._init_progress(trainer, pl_module)
self._init_progress(trainer)
def __getstate__(self):
# can't pickle the rich progress objects
@ -307,12 +310,11 @@ class RichProgressBar(ProgressBarBase):
def __setstate__(self, state):
self.__dict__ = state
# reset console reference after loading progress
self._console = Console()
state["_console"] = Console()
def on_sanity_check_start(self, trainer, pl_module):
super().on_sanity_check_start(trainer, pl_module)
self._init_progress(trainer, pl_module)
self._init_progress(trainer)
self.val_sanity_progress_bar_id = self._add_task(trainer.num_sanity_val_steps, self.sanity_check_description)
def on_sanity_check_end(self, trainer, pl_module):
@ -333,10 +335,10 @@ class RichProgressBar(ProgressBarBase):
train_description = self._get_train_description(trainer.current_epoch)
if self.main_progress_bar_id is not None and self._leave:
self._stop_progress()
self._init_progress(trainer, pl_module)
self._init_progress(trainer)
if self.main_progress_bar_id is None:
self.main_progress_bar_id = self._add_task(total_batches, train_description)
else:
elif self.progress is not None:
self.progress.reset(
self.main_progress_bar_id, total=total_batches, description=train_description, visible=True
)
@ -377,6 +379,7 @@ class RichProgressBar(ProgressBarBase):
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
# After each training batch: run the parent hook, call `_update` for the main
# progress bar task, then pass the current metrics on via `_update_metrics`.
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
self._update(self.main_progress_bar_id)
self._update_metrics(trainer, pl_module)
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
@ -419,6 +422,11 @@ class RichProgressBar(ProgressBarBase):
self.test_progress_bar_id: Optional[int] = None
self.predict_progress_bar_id: Optional[int] = None
def _update_metrics(self, trainer, pl_module) -> None:
# Fetch the current metrics via `get_metrics` and hand them to the metrics
# column, if one has been created.
# NOTE(review): `get_metrics` is called before the `_metric_component` check, so
# it runs even when no column exists - presumably intentional (keeps every
# process on the same call path); confirm before reordering.
metrics = self.get_metrics(trainer, pl_module)
# `_metric_component` is initialised to None in `__init__` and only assigned
# inside `_init_progress`, so it can still be unset at this point.
if self._metric_component:
self._metric_component.update(metrics)
def teardown(self, trainer, pl_module, stage: Optional[str] = None) -> None:
self._stop_progress()
@ -441,7 +449,7 @@ class RichProgressBar(ProgressBarBase):
def test_progress_bar(self) -> Task:
return self.progress.tasks[self.test_progress_bar_id]
def configure_columns(self, trainer, pl_module) -> list:
def configure_columns(self, trainer) -> list:
return [
TextColumn("[progress.description]{task.description}"),
CustomBarColumn(
@ -452,5 +460,4 @@ class RichProgressBar(ProgressBarBase):
BatchesProcessedColumn(style=self.theme.batch_process),
CustomTimeColumn(style=self.theme.time),
ProcessingSpeedColumn(style=self.theme.processing_speed),
MetricsTextColumn(trainer, pl_module),
]

View File

@ -150,15 +150,15 @@ def test_rich_progress_bar_configure_columns():
custom_column = TextColumn("[progress.description]Testing Rich!")
class CustomRichProgressBar(RichProgressBar):
def configure_columns(self, trainer, pl_module):
def configure_columns(self, trainer):
return [custom_column]
progress_bar = CustomRichProgressBar()
progress_bar._init_progress(Mock(), Mock())
progress_bar._init_progress(Mock())
assert progress_bar.progress.columns[0] == custom_column
assert len(progress_bar.progress.columns) == 1
assert len(progress_bar.progress.columns) == 2
@RunIf(rich=True)