fix
This commit is contained in:
parent
ecf7933696
commit
b840cd4ecc
|
@ -297,13 +297,14 @@ def _init_dist_connection(
|
||||||
# PyTorch >= 2.4 warns about undestroyed NCCL process group, so we need to do it at program exit
|
# PyTorch >= 2.4 warns about undestroyed NCCL process group, so we need to do it at program exit
|
||||||
atexit.register(_destroy_dist_connection)
|
atexit.register(_destroy_dist_connection)
|
||||||
|
|
||||||
# On rank=0 let everyone know training is starting
|
# On local_rank=0 let everyone know training is starting
|
||||||
rank_zero_info(
|
if cluster_environment.local_rank() == 0:
|
||||||
f"{'-' * 100}\n"
|
log.info(
|
||||||
f"distributed_backend={torch_distributed_backend}\n"
|
f"{'-' * 100}\n"
|
||||||
f"All distributed processes registered. Starting with {world_size} processes\n"
|
f"Distributed backend: {torch_distributed_backend.upper()}\n"
|
||||||
f"{'-' * 100}\n"
|
f"All distributed processes registered. Starting with {world_size} processes\n"
|
||||||
)
|
f"{'-' * 100}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _destroy_dist_connection() -> None:
|
def _destroy_dist_connection() -> None:
|
||||||
|
|
|
@ -11,6 +11,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
from typing import Any, Dict, Optional, Union
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
@ -19,6 +20,8 @@ import lightning.pytorch as pl
|
||||||
from lightning.pytorch.callbacks import Callback
|
from lightning.pytorch.callbacks import Callback
|
||||||
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
|
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
|
||||||
|
|
||||||
|
_log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ProgressBar(Callback):
|
class ProgressBar(Callback):
|
||||||
r"""The base class for progress bars in Lightning. It is a :class:`~lightning.pytorch.callbacks.Callback` that keeps
|
r"""The base class for progress bars in Lightning. It is a :class:`~lightning.pytorch.callbacks.Callback` that keeps
|
||||||
|
@ -173,6 +176,8 @@ class ProgressBar(Callback):
|
||||||
self._trainer = trainer
|
self._trainer = trainer
|
||||||
if not trainer.is_global_zero:
|
if not trainer.is_global_zero:
|
||||||
self.disable()
|
self.disable()
|
||||||
|
if trainer.node_rank > 0 and trainer.local_rank == 0:
|
||||||
|
_log.info("The progress bar output will appear on node 0.")
|
||||||
|
|
||||||
def get_metrics(
|
def get_metrics(
|
||||||
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
|
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
|
||||||
|
|
|
@ -11,7 +11,6 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
@ -25,7 +24,6 @@ from lightning.pytorch.callbacks.progress.progress_bar import ProgressBar
|
||||||
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
||||||
|
|
||||||
_RICH_AVAILABLE = RequirementCache("rich>=10.2.2")
|
_RICH_AVAILABLE = RequirementCache("rich>=10.2.2")
|
||||||
_log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
if _RICH_AVAILABLE:
|
if _RICH_AVAILABLE:
|
||||||
|
@ -366,10 +364,6 @@ class RichProgressBar(ProgressBar):
|
||||||
if self.progress:
|
if self.progress:
|
||||||
self.progress.refresh()
|
self.progress.refresh()
|
||||||
|
|
||||||
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
|
|
||||||
if trainer.node_rank > 0 and trainer.local_rank == 0:
|
|
||||||
_log.info("The progress bar output will appear on node 0.")
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||||
self._init_progress(trainer)
|
self._init_progress(trainer)
|
||||||
|
|
|
@ -12,7 +12,6 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
@ -34,7 +33,6 @@ import lightning.pytorch as pl
|
||||||
from lightning.pytorch.callbacks.progress.progress_bar import ProgressBar
|
from lightning.pytorch.callbacks.progress.progress_bar import ProgressBar
|
||||||
from lightning.pytorch.utilities.rank_zero import rank_zero_debug
|
from lightning.pytorch.utilities.rank_zero import rank_zero_debug
|
||||||
|
|
||||||
_log = logging.getLogger(__name__)
|
|
||||||
_PAD_SIZE = 5
|
_PAD_SIZE = 5
|
||||||
|
|
||||||
|
|
||||||
|
@ -248,10 +246,6 @@ class TQDMProgressBar(ProgressBar):
|
||||||
bar_format=self.BAR_FORMAT,
|
bar_format=self.BAR_FORMAT,
|
||||||
)
|
)
|
||||||
|
|
||||||
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
|
|
||||||
if trainer.node_rank > 0 and trainer.local_rank == 0:
|
|
||||||
_log.info("The progress bar output will appear on node 0.")
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def on_sanity_check_start(self, *_: Any) -> None:
|
def on_sanity_check_start(self, *_: Any) -> None:
|
||||||
self.val_progress_bar = self.init_sanity_tqdm()
|
self.val_progress_bar = self.init_sanity_tqdm()
|
||||||
|
|
Loading…
Reference in New Issue