From ecf7933696586c0e1141790da86d6bc88c92bcef Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 20 Jul 2024 14:58:12 +0200 Subject: [PATCH] Print info message about progress bar output on node 0 --- src/lightning/pytorch/callbacks/progress/rich_progress.py | 7 +++++++ src/lightning/pytorch/callbacks/progress/tqdm_progress.py | 6 ++++++ 2 files changed, 13 insertions(+) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 5ef45baf96..fa97eac7bc 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import math from dataclasses import dataclass from datetime import timedelta @@ -24,6 +25,8 @@ from lightning.pytorch.callbacks.progress.progress_bar import ProgressBar from lightning.pytorch.utilities.types import STEP_OUTPUT _RICH_AVAILABLE = RequirementCache("rich>=10.2.2") +_log = logging.getLogger(__name__) + if _RICH_AVAILABLE: from rich import get_console, reconfigure @@ -363,6 +366,10 @@ class RichProgressBar(ProgressBar): if self.progress: self.progress.refresh() + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: + if trainer.node_rank > 0 and trainer.local_rank == 0: + _log.info("The progress bar output will appear on node 0.") + @override def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self._init_progress(trainer) diff --git a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py index bf9e238a01..c24ee10247 100644 --- a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py +++ b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import importlib +import logging import math import os import sys @@ -33,6 +34,7 @@ import lightning.pytorch as pl from lightning.pytorch.callbacks.progress.progress_bar import ProgressBar from lightning.pytorch.utilities.rank_zero import rank_zero_debug +_log = logging.getLogger(__name__) _PAD_SIZE = 5 @@ -246,6 +248,10 @@ class TQDMProgressBar(ProgressBar): bar_format=self.BAR_FORMAT, ) + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: + if trainer.node_rank > 0 and trainer.local_rank == 0: + _log.info("The progress bar output will appear on node 0.") + @override def on_sanity_check_start(self, *_: Any) -> None: self.val_progress_bar = self.init_sanity_tqdm()