fix

parent ecf7933696
commit b840cd4ecc
@@ -297,13 +297,14 @@ def _init_dist_connection(
     # PyTorch >= 2.4 warns about undestroyed NCCL process group, so we need to do it at program exit
     atexit.register(_destroy_dist_connection)
 
-    # On rank=0 let everyone know training is starting
-    rank_zero_info(
-        f"{'-' * 100}\n"
-        f"distributed_backend={torch_distributed_backend}\n"
-        f"All distributed processes registered. Starting with {world_size} processes\n"
-        f"{'-' * 100}\n"
-    )
+    # On local_rank=0 let everyone know training is starting
+    if cluster_environment.local_rank() == 0:
+        log.info(
+            f"{'-' * 100}\n"
+            f"Distributed backend: {torch_distributed_backend.upper()}\n"
+            f"All distributed processes registered. Starting with {world_size} processes\n"
+            f"{'-' * 100}\n"
+        )
 
 
 def _destroy_dist_connection() -> None:
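As a rough standalone illustration (not part of the diff, and not Lightning's actual API): the hunk above swaps the global-rank-0 rank_zero_info banner for a log.info call gated on cluster_environment.local_rank() == 0, which appears intended to emit the startup banner once per node rather than once per job. The sketch below simulates a 2-node x 2-process layout to show which ranks would log under each gate; every name in it is made up for the example.

# Hypothetical sketch: which ranks emit the startup banner, old vs. new gate.
import logging

logging.basicConfig(level=logging.INFO, format="%(message)s")
log = logging.getLogger(__name__)

# (global_rank, node_rank, local_rank) for a 2-node x 2-process job
ranks = [(0, 0, 0), (1, 0, 1), (2, 1, 0), (3, 1, 1)]

for global_rank, node_rank, local_rank in ranks:
    old_gate = global_rank == 0  # rank_zero_info: only the single global rank 0
    new_gate = local_rank == 0   # local_rank() == 0: the first process on every node
    log.info(
        f"global_rank={global_rank} node_rank={node_rank} "
        f"old_banner={old_gate} new_banner={new_gate}"
    )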
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 from typing import Any, Dict, Optional, Union
 
 from typing_extensions import override
@@ -19,6 +20,8 @@ import lightning.pytorch as pl
 from lightning.pytorch.callbacks import Callback
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn
 
+_log = logging.getLogger(__name__)
+
 
 class ProgressBar(Callback):
     r"""The base class for progress bars in Lightning. It is a :class:`~lightning.pytorch.callbacks.Callback` that keeps
@@ -173,6 +176,8 @@ class ProgressBar(Callback):
         self._trainer = trainer
         if not trainer.is_global_zero:
             self.disable()
+        if trainer.node_rank > 0 and trainer.local_rank == 0:
+            _log.info("The progress bar output will appear on node 0.")
 
     def get_metrics(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
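A minimal sketch of the structural effect (stub classes, not the real Lightning ones): with the node-0 notice and the global-zero disable check living in the base ProgressBar.setup, concrete bars such as TQDMProgressBar and RichProgressBar can drop their own setup overrides (see the removals in the hunks below) and still get the same behaviour through inheritance. FakeTrainer, BaseBar, and FakeTQDMBar here are hypothetical stand-ins.

import logging

logging.basicConfig(level=logging.INFO, format="%(message)s")
_log = logging.getLogger(__name__)


class FakeTrainer:
    # Hypothetical stand-in exposing just the attributes setup() reads.
    def __init__(self, node_rank: int, local_rank: int, is_global_zero: bool) -> None:
        self.node_rank = node_rank
        self.local_rank = local_rank
        self.is_global_zero = is_global_zero


class BaseBar:
    def __init__(self) -> None:
        self.enabled = True

    def disable(self) -> None:
        self.enabled = False

    def setup(self, trainer: FakeTrainer) -> None:
        # Mirrors the logic added to ProgressBar.setup in the hunk above.
        self._trainer = trainer
        if not trainer.is_global_zero:
            self.disable()
        if trainer.node_rank > 0 and trainer.local_rank == 0:
            _log.info("The progress bar output will appear on node 0.")


class FakeTQDMBar(BaseBar):
    # No setup override needed any more; behaviour comes from the base class.
    pass


bar = FakeTQDMBar()
bar.setup(FakeTrainer(node_rank=1, local_rank=0, is_global_zero=False))
print("bar enabled:", bar.enabled)  # False on non-global-zero ranks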
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
 import math
 from dataclasses import dataclass
 from datetime import timedelta
@@ -25,7 +24,6 @@ from lightning.pytorch.callbacks.progress.progress_bar import ProgressBar
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 
 _RICH_AVAILABLE = RequirementCache("rich>=10.2.2")
-_log = logging.getLogger(__name__)
 
 
 if _RICH_AVAILABLE:
@@ -366,10 +364,6 @@ class RichProgressBar(ProgressBar):
         if self.progress:
             self.progress.refresh()
 
-    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
-        if trainer.node_rank > 0 and trainer.local_rank == 0:
-            _log.info("The progress bar output will appear on node 0.")
-
     @override
     def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._init_progress(trainer)
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import importlib
-import logging
 import math
 import os
 import sys
@@ -34,7 +33,6 @@ import lightning.pytorch as pl
 from lightning.pytorch.callbacks.progress.progress_bar import ProgressBar
 from lightning.pytorch.utilities.rank_zero import rank_zero_debug
 
-_log = logging.getLogger(__name__)
 _PAD_SIZE = 5
 
 
@@ -248,10 +246,6 @@ class TQDMProgressBar(ProgressBar):
             bar_format=self.BAR_FORMAT,
         )
 
-    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
-        if trainer.node_rank > 0 and trainer.local_rank == 0:
-            _log.info("The progress bar output will appear on node 0.")
-
     @override
     def on_sanity_check_start(self, *_: Any) -> None:
         self.val_progress_bar = self.init_sanity_tqdm()