From b840cd4eccad1adc7fdae3612783954e8c338c1c Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Sat, 20 Jul 2024 15:13:42 +0200
Subject: [PATCH] fix

---
 src/lightning/fabric/utilities/distributed.py   | 15 ++++++++-------
 .../pytorch/callbacks/progress/progress_bar.py  |  5 +++++
 .../pytorch/callbacks/progress/rich_progress.py |  6 ------
 .../pytorch/callbacks/progress/tqdm_progress.py |  6 ------
 4 files changed, 13 insertions(+), 19 deletions(-)

diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py
index 75b2f7c580..7292e34e03 100644
--- a/src/lightning/fabric/utilities/distributed.py
+++ b/src/lightning/fabric/utilities/distributed.py
@@ -297,13 +297,14 @@ def _init_dist_connection(
         # PyTorch >= 2.4 warns about undestroyed NCCL process group, so we need to do it at program exit
         atexit.register(_destroy_dist_connection)
 
-    # On rank=0 let everyone know training is starting
-    rank_zero_info(
-        f"{'-' * 100}\n"
-        f"distributed_backend={torch_distributed_backend}\n"
-        f"All distributed processes registered. Starting with {world_size} processes\n"
-        f"{'-' * 100}\n"
-    )
+    # On local_rank=0 let everyone know training is starting
+    if cluster_environment.local_rank() == 0:
+        log.info(
+            f"{'-' * 100}\n"
+            f"Distributed backend: {torch_distributed_backend.upper()}\n"
+            f"All distributed processes registered. Starting with {world_size} processes\n"
+            f"{'-' * 100}\n"
+        )
 
 
 def _destroy_dist_connection() -> None:
diff --git a/src/lightning/pytorch/callbacks/progress/progress_bar.py b/src/lightning/pytorch/callbacks/progress/progress_bar.py
index 785bf65af4..bb71a7d661 100644
--- a/src/lightning/pytorch/callbacks/progress/progress_bar.py
+++ b/src/lightning/pytorch/callbacks/progress/progress_bar.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 from typing import Any, Dict, Optional, Union
 
 from typing_extensions import override
@@ -19,6 +20,8 @@ import lightning.pytorch as pl
 from lightning.pytorch.callbacks import Callback
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn
 
+_log = logging.getLogger(__name__)
+
 
 class ProgressBar(Callback):
     r"""The base class for progress bars in Lightning. It is a :class:`~lightning.pytorch.callbacks.Callback` that keeps
@@ -173,6 +176,8 @@ class ProgressBar(Callback):
         self._trainer = trainer
         if not trainer.is_global_zero:
             self.disable()
+        if trainer.node_rank > 0 and trainer.local_rank == 0:
+            _log.info("The progress bar output will appear on node 0.")
 
     def get_metrics(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py
index fa97eac7bc..33bc34eada 100644
--- a/src/lightning/pytorch/callbacks/progress/rich_progress.py
+++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
 import math
 from dataclasses import dataclass
 from datetime import timedelta
@@ -25,7 +24,6 @@ from lightning.pytorch.callbacks.progress.progress_bar import ProgressBar
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 
 _RICH_AVAILABLE = RequirementCache("rich>=10.2.2")
-_log = logging.getLogger(__name__)
 
 if _RICH_AVAILABLE:
@@ -366,10 +364,6 @@ class RichProgressBar(ProgressBar):
         if self.progress:
             self.progress.refresh()
 
-    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
-        if trainer.node_rank > 0 and trainer.local_rank == 0:
-            _log.info("The progress bar output will appear on node 0.")
-
     @override
     def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._init_progress(trainer)
diff --git a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py
index c24ee10247..bf9e238a01 100644
--- a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py
+++ b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import importlib
-import logging
 import math
 import os
 import sys
@@ -34,7 +33,6 @@ import lightning.pytorch as pl
 from lightning.pytorch.callbacks.progress.progress_bar import ProgressBar
 from lightning.pytorch.utilities.rank_zero import rank_zero_debug
 
-_log = logging.getLogger(__name__)
 _PAD_SIZE = 5
 
 
@@ -248,10 +246,6 @@ class TQDMProgressBar(ProgressBar):
             bar_format=self.BAR_FORMAT,
         )
 
-    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
-        if trainer.node_rank > 0 and trainer.local_rank == 0:
-            _log.info("The progress bar output will appear on node 0.")
-
     @override
    def on_sanity_check_start(self, *_: Any) -> None:
         self.val_progress_bar = self.init_sanity_tqdm()
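
Note: the distributed.py hunk replaces rank_zero_info (which prints only on global rank 0, i.e. only on the first node) with logging gated on local_rank == 0, so the startup banner appears once per node. Below is a minimal, self-contained sketch of that pattern; it is not the patch's code. The LOCAL_RANK environment-variable lookup stands in for Lightning's ClusterEnvironment.local_rank() and assumes a launcher such as torchrun, which sets LOCAL_RANK for each process on a node.

    import logging
    import os

    logging.basicConfig(level=logging.INFO)
    log = logging.getLogger(__name__)

    # Each process on a node gets a distinct local rank from the launcher.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))

    if local_rank == 0:
        # Gating on local rank (instead of global rank 0) emits the message
        # once per node, so every machine confirms its processes registered.
        log.info("All distributed processes registered.")

The progress-bar hunks follow the same idea: the per-node informational message moves from the two concrete setup() overrides (RichProgressBar, TQDMProgressBar) into the shared ProgressBar.setup(), so any progress bar subclass logs it exactly once per non-zero node.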