awaelchli 2024-07-20 15:13:42 +02:00
parent ecf7933696
commit b840cd4ecc
4 changed files with 13 additions and 19 deletions


@@ -297,13 +297,14 @@ def _init_dist_connection(
     # PyTorch >= 2.4 warns about undestroyed NCCL process group, so we need to do it at program exit
     atexit.register(_destroy_dist_connection)
-    # On rank=0 let everyone know training is starting
-    rank_zero_info(
-        f"{'-' * 100}\n"
-        f"distributed_backend={torch_distributed_backend}\n"
-        f"All distributed processes registered. Starting with {world_size} processes\n"
-        f"{'-' * 100}\n"
-    )
+    # On local_rank=0 let everyone know training is starting
+    if cluster_environment.local_rank() == 0:
+        log.info(
+            f"{'-' * 100}\n"
+            f"Distributed backend: {torch_distributed_backend.upper()}\n"
+            f"All distributed processes registered. Starting with {world_size} processes\n"
+            f"{'-' * 100}\n"
+        )
 def _destroy_dist_connection() -> None:
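For context: the startup banner is now printed once per node (by each node's local-rank-zero process) instead of only by global rank zero via rank_zero_info. A minimal standalone sketch of the same gating idea, assuming the conventional LOCAL_RANK environment variable as a stand-in for cluster_environment.local_rank(); the helper name below is invented for illustration:

import logging
import os

log = logging.getLogger(__name__)


def announce_start(torch_distributed_backend: str, world_size: int) -> None:
    # Only the first process on each node logs the banner, so a multi-node job
    # prints one banner per node rather than one line per rank or only on rank 0.
    if int(os.environ.get("LOCAL_RANK", "0")) == 0:
        log.info(
            f"{'-' * 100}\n"
            f"Distributed backend: {torch_distributed_backend.upper()}\n"
            f"All distributed processes registered. Starting with {world_size} processes\n"
            f"{'-' * 100}\n"
        )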


@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 from typing import Any, Dict, Optional, Union
 from typing_extensions import override
@@ -19,6 +20,8 @@ import lightning.pytorch as pl
 from lightning.pytorch.callbacks import Callback
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn
+_log = logging.getLogger(__name__)
 class ProgressBar(Callback):
     r"""The base class for progress bars in Lightning. It is a :class:`~lightning.pytorch.callbacks.Callback` that keeps
@@ -173,6 +176,8 @@ class ProgressBar(Callback):
         self._trainer = trainer
         if not trainer.is_global_zero:
             self.disable()
+        if trainer.node_rank > 0 and trainer.local_rank == 0:
+            _log.info("The progress bar output will appear on node 0.")
     def get_metrics(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
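Because the hint now lives in the base class's setup(), any progress bar that chains to super().setup() inherits both the disable-on-non-global-zero behavior and the node-0 message. An illustrative sketch of a custom bar; the class name and _enabled flag are invented for this example, only the import path and the setup signature come from the diff:

import lightning.pytorch as pl
from lightning.pytorch.callbacks.progress.progress_bar import ProgressBar


class MinimalBar(ProgressBar):
    """Toy progress bar that relies on the base-class rank handling."""

    def __init__(self) -> None:
        super().__init__()
        self._enabled = True

    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
        # Chaining to super() keeps the global-zero disable logic and the
        # "progress bar output will appear on node 0" hint added above.
        super().setup(trainer, pl_module, stage)

    def disable(self) -> None:
        self._enabled = False

    def enable(self) -> None:
        self._enabled = True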


@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
 import math
 from dataclasses import dataclass
 from datetime import timedelta
@@ -25,7 +24,6 @@ from lightning.pytorch.callbacks.progress.progress_bar import ProgressBar
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 _RICH_AVAILABLE = RequirementCache("rich>=10.2.2")
-_log = logging.getLogger(__name__)
 if _RICH_AVAILABLE:
@@ -366,10 +364,6 @@ class RichProgressBar(ProgressBar):
         if self.progress:
             self.progress.refresh()
-    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
-        if trainer.node_rank > 0 and trainer.local_rank == 0:
-            _log.info("The progress bar output will appear on node 0.")
     @override
     def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._init_progress(trainer)


@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import importlib
-import logging
 import math
 import os
 import sys
@@ -34,7 +33,6 @@ import lightning.pytorch as pl
 from lightning.pytorch.callbacks.progress.progress_bar import ProgressBar
 from lightning.pytorch.utilities.rank_zero import rank_zero_debug
-_log = logging.getLogger(__name__)
 _PAD_SIZE = 5
@@ -248,10 +246,6 @@ class TQDMProgressBar(ProgressBar):
             bar_format=self.BAR_FORMAT,
         )
-    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
-        if trainer.node_rank > 0 and trainer.local_rank == 0:
-            _log.info("The progress bar output will appear on node 0.")
     @override
     def on_sanity_check_start(self, *_: Any) -> None:
         self.val_progress_bar = self.init_sanity_tqdm()
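With their own setup() overrides removed, RichProgressBar and TQDMProgressBar fall back to the inherited ProgressBar.setup(). A quick, informal check from a Python shell, assuming lightning (with tqdm) is installed and the public callback exports are used:

from lightning.pytorch.callbacks import ProgressBar, TQDMProgressBar

# After this change, TQDMProgressBar no longer defines its own setup(),
# so attribute lookup resolves to the base-class implementation.
assert TQDMProgressBar.setup is ProgressBar.setup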