Print info message about progress bar output on node 0

This commit is contained in:
awaelchli 2024-07-20 14:58:12 +02:00
parent 1cd774197d
commit ecf7933696
2 changed files with 13 additions and 0 deletions

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
from dataclasses import dataclass
from datetime import timedelta
@ -24,6 +25,8 @@ from lightning.pytorch.callbacks.progress.progress_bar import ProgressBar
from lightning.pytorch.utilities.types import STEP_OUTPUT
_RICH_AVAILABLE = RequirementCache("rich>=10.2.2")
_log = logging.getLogger(__name__)
if _RICH_AVAILABLE:
from rich import get_console, reconfigure
@ -363,6 +366,10 @@ class RichProgressBar(ProgressBar):
if self.progress:
self.progress.refresh()
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
if trainer.node_rank > 0 and trainer.local_rank == 0:
_log.info("The progress bar output will appear on node 0.")
@override
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._init_progress(trainer)

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import logging
import math
import os
import sys
@ -33,6 +34,7 @@ import lightning.pytorch as pl
from lightning.pytorch.callbacks.progress.progress_bar import ProgressBar
from lightning.pytorch.utilities.rank_zero import rank_zero_debug
_log = logging.getLogger(__name__)
_PAD_SIZE = 5
@ -246,6 +248,10 @@ class TQDMProgressBar(ProgressBar):
bar_format=self.BAR_FORMAT,
)
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
if trainer.node_rank > 0 and trainer.local_rank == 0:
_log.info("The progress bar output will appear on node 0.")
@override
def on_sanity_check_start(self, *_: Any) -> None:
self.val_progress_bar = self.init_sanity_tqdm()