Fix deadlocks for distributed training with RichProgressBar (#10428)
This commit is contained in:
parent
21eafafcb0
commit
5eeca87e98
|
@ -107,6 +107,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed issue with pickling `CSVLogger` after a call to `CSVLogger.save` ([#10388](https://github.com/PyTorchLightning/pytorch-lightning/pull/10388))
|
||||
|
||||
|
||||
- Fixed deadlocks for distributed training with `RichProgressBar` ([#10428](https://github.com/PyTorchLightning/pytorch-lightning/pull/10428))
|
||||
|
||||
|
||||
- Fixed logging with `on_step=True` in epoch-level hooks causing unintended side effects. Logging with `on_step=True` in epoch-level hooks will now correctly raise an error ([#10409](https://github.com/PyTorchLightning/pytorch-lightning/pull/10409))
|
||||
|
||||
|
||||
|
|
|
@ -129,13 +129,19 @@ if _RICH_AVAILABLE:
|
|||
class MetricsTextColumn(ProgressColumn):
|
||||
"""A column containing text."""
|
||||
|
||||
def __init__(self, trainer, pl_module):
|
||||
def __init__(self, trainer):
|
||||
self._trainer = trainer
|
||||
self._pl_module = pl_module
|
||||
self._tasks = {}
|
||||
self._current_task_id = 0
|
||||
self._metrics = {}
|
||||
super().__init__()
|
||||
|
||||
def update(self, metrics):
|
||||
# Called when metrics are ready to be rendered.
|
||||
# Storing the metrics here means render never has to request them itself from a separate
|
||||
# thread, which could cause deadlocks.
|
||||
self._metrics = metrics
|
||||
|
||||
def render(self, task) -> Text:
|
||||
from pytorch_lightning.trainer.states import TrainerFn
|
||||
|
||||
|
@ -149,14 +155,8 @@ if _RICH_AVAILABLE:
|
|||
if self._trainer.training and task.id != self._current_task_id:
|
||||
return self._tasks[task.id]
|
||||
_text = ""
|
||||
# TODO(@daniellepintz): make this code cleaner
|
||||
progress_bar_callback = getattr(self._trainer, "progress_bar_callback", None)
|
||||
if progress_bar_callback:
|
||||
metrics = self._trainer.progress_bar_callback.get_metrics(self._trainer, self._pl_module)
|
||||
else:
|
||||
metrics = self._trainer.progress_bar_metrics
|
||||
|
||||
for k, v in metrics.items():
|
||||
for k, v in self._metrics.items():
|
||||
_text += f"{k}: {round(v, 3) if isinstance(v, float) else v} "
|
||||
return Text(_text, justify="left")
|
||||
|
||||
|
@ -225,9 +225,9 @@ class RichProgressBar(ProgressBarBase):
|
|||
self.progress: Optional[Progress] = None
|
||||
self.val_sanity_progress_bar_id: Optional[int] = None
|
||||
self._reset_progress_bar_ids()
|
||||
self._metric_component = None
|
||||
self._progress_stopped: bool = False
|
||||
self.theme = theme
|
||||
self._console: Console = Console()
|
||||
|
||||
@property
|
||||
def refresh_rate_per_second(self) -> float:
|
||||
|
@ -268,12 +268,15 @@ class RichProgressBar(ProgressBarBase):
|
|||
def predict_description(self) -> str:
|
||||
return "Predicting"
|
||||
|
||||
def _init_progress(self, trainer, pl_module):
|
||||
if self.progress is None or self._progress_stopped:
|
||||
def _init_progress(self, trainer):
|
||||
if self.is_enabled and (self.progress is None or self._progress_stopped):
|
||||
self._reset_progress_bar_ids()
|
||||
self._console: Console = Console()
|
||||
self._console.clear_live()
|
||||
self._metric_component = MetricsTextColumn(trainer)
|
||||
self.progress = CustomProgress(
|
||||
*self.configure_columns(trainer, pl_module),
|
||||
*self.configure_columns(trainer),
|
||||
self._metric_component,
|
||||
refresh_per_second=self.refresh_rate_per_second,
|
||||
disable=self.is_disabled,
|
||||
console=self._console,
|
||||
|
@ -284,19 +287,19 @@ class RichProgressBar(ProgressBarBase):
|
|||
|
||||
def on_train_start(self, trainer, pl_module):
|
||||
super().on_train_start(trainer, pl_module)
|
||||
self._init_progress(trainer, pl_module)
|
||||
self._init_progress(trainer)
|
||||
|
||||
def on_predict_start(self, trainer, pl_module):
|
||||
super().on_predict_start(trainer, pl_module)
|
||||
self._init_progress(trainer, pl_module)
|
||||
self._init_progress(trainer)
|
||||
|
||||
def on_test_start(self, trainer, pl_module):
|
||||
super().on_test_start(trainer, pl_module)
|
||||
self._init_progress(trainer, pl_module)
|
||||
self._init_progress(trainer)
|
||||
|
||||
def on_validation_start(self, trainer, pl_module):
|
||||
super().on_validation_start(trainer, pl_module)
|
||||
self._init_progress(trainer, pl_module)
|
||||
self._init_progress(trainer)
|
||||
|
||||
def __getstate__(self):
|
||||
# can't pickle the rich progress objects
|
||||
|
@ -307,12 +310,11 @@ class RichProgressBar(ProgressBarBase):
|
|||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__ = state
|
||||
# reset console reference after loading progress
|
||||
self._console = Console()
|
||||
state["_console"] = Console()
|
||||
|
||||
def on_sanity_check_start(self, trainer, pl_module):
|
||||
super().on_sanity_check_start(trainer, pl_module)
|
||||
self._init_progress(trainer, pl_module)
|
||||
self._init_progress(trainer)
|
||||
self.val_sanity_progress_bar_id = self._add_task(trainer.num_sanity_val_steps, self.sanity_check_description)
|
||||
|
||||
def on_sanity_check_end(self, trainer, pl_module):
|
||||
|
@ -333,10 +335,10 @@ class RichProgressBar(ProgressBarBase):
|
|||
train_description = self._get_train_description(trainer.current_epoch)
|
||||
if self.main_progress_bar_id is not None and self._leave:
|
||||
self._stop_progress()
|
||||
self._init_progress(trainer, pl_module)
|
||||
self._init_progress(trainer)
|
||||
if self.main_progress_bar_id is None:
|
||||
self.main_progress_bar_id = self._add_task(total_batches, train_description)
|
||||
else:
|
||||
elif self.progress is not None:
|
||||
self.progress.reset(
|
||||
self.main_progress_bar_id, total=total_batches, description=train_description, visible=True
|
||||
)
|
||||
|
@ -377,6 +379,7 @@ class RichProgressBar(ProgressBarBase):
|
|||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
||||
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
|
||||
self._update(self.main_progress_bar_id)
|
||||
self._update_metrics(trainer, pl_module)
|
||||
|
||||
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
||||
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
|
||||
|
@ -419,6 +422,11 @@ class RichProgressBar(ProgressBarBase):
|
|||
self.test_progress_bar_id: Optional[int] = None
|
||||
self.predict_progress_bar_id: Optional[int] = None
|
||||
|
||||
def _update_metrics(self, trainer, pl_module) -> None:
|
||||
metrics = self.get_metrics(trainer, pl_module)
|
||||
if self._metric_component:
|
||||
self._metric_component.update(metrics)
|
||||
|
||||
def teardown(self, trainer, pl_module, stage: Optional[str] = None) -> None:
|
||||
self._stop_progress()
|
||||
|
||||
|
@ -441,7 +449,7 @@ class RichProgressBar(ProgressBarBase):
|
|||
def test_progress_bar(self) -> Task:
|
||||
return self.progress.tasks[self.test_progress_bar_id]
|
||||
|
||||
def configure_columns(self, trainer, pl_module) -> list:
|
||||
def configure_columns(self, trainer) -> list:
|
||||
return [
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
CustomBarColumn(
|
||||
|
@ -452,5 +460,4 @@ class RichProgressBar(ProgressBarBase):
|
|||
BatchesProcessedColumn(style=self.theme.batch_process),
|
||||
CustomTimeColumn(style=self.theme.time),
|
||||
ProcessingSpeedColumn(style=self.theme.processing_speed),
|
||||
MetricsTextColumn(trainer, pl_module),
|
||||
]
|
||||
|
|
|
@ -150,15 +150,15 @@ def test_rich_progress_bar_configure_columns():
|
|||
custom_column = TextColumn("[progress.description]Testing Rich!")
|
||||
|
||||
class CustomRichProgressBar(RichProgressBar):
|
||||
def configure_columns(self, trainer, pl_module):
|
||||
def configure_columns(self, trainer):
|
||||
return [custom_column]
|
||||
|
||||
progress_bar = CustomRichProgressBar()
|
||||
|
||||
progress_bar._init_progress(Mock(), Mock())
|
||||
progress_bar._init_progress(Mock())
|
||||
|
||||
assert progress_bar.progress.columns[0] == custom_column
|
||||
assert len(progress_bar.progress.columns) == 1
|
||||
assert len(progress_bar.progress.columns) == 2
|
||||
|
||||
|
||||
@RunIf(rich=True)
|
||||
|
|
Loading…
Reference in New Issue