Add `Strategy.on_exception` (#16646)
parent 1288e4ccc4
commit 74ee699dfd
@@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added suffix option to DDP strategy names to enable `find_unused_parameters=True`, for example `strategy="ddp_find_unused_parameters_true"` ([#16611](https://github.com/Lightning-AI/lightning/pull/16611))

- Added a new method `Strategy.on_exception` to the strategy base interface ([#16646](https://github.com/Lightning-AI/lightning/pull/16646))


### Changed

- "Native" suffix removal ([#16490](https://github.com/Lightning-AI/lightning/pull/16490))
@@ -45,6 +45,7 @@ from lightning.pytorch.strategies.parallel import ParallelStrategy
from lightning.pytorch.strategies.strategy import TBroadcast
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.distributed import register_ddp_comm_hook
from lightning.pytorch.utilities.exceptions import _augment_message
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only
from lightning.pytorch.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep
@@ -364,6 +365,18 @@ class DDPStrategy(ParallelStrategy):
            description=f"{cls.__class__.__name__}",
        )

    def on_exception(self, exception: BaseException) -> None:
        _augment_message(
            exception,
            pattern=".*Expected to have finished reduction in the prior iteration.*",
            new_message=(
                "It looks like your LightningModule has parameters that were not used in producing the loss returned"
                " by training_step. If this is intentional, you must enable the detection of unused parameters in DDP,"
                " either by setting the string value `strategy='ddp_find_unused_parameters_true'`"
                " or by setting the flag in the strategy with `strategy=DDPStrategy(find_unused_parameters=True)`."
            ),
        )

    def teardown(self) -> None:
        log.detail(f"{self.__class__.__name__}: tearing down strategy")
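Note: the augmented message offers two equivalent fixes. A minimal sketch of both, assuming a typical multi-GPU setup (the string alias comes from the suffix option added in #16611):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.strategies import DDPStrategy

# Option 1: the registered string alias
trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp_find_unused_parameters_true")

# Option 2: configure the strategy object directly
trainer = Trainer(accelerator="gpu", devices=2, strategy=DDPStrategy(find_unused_parameters=True))
```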
@@ -42,6 +42,7 @@ from lightning.pytorch.strategies.parallel import ParallelStrategy
from lightning.pytorch.strategies.strategy import TBroadcast
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.distributed import register_ddp_comm_hook
from lightning.pytorch.utilities.exceptions import _augment_message
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only
from lightning.pytorch.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep
@@ -330,6 +331,18 @@ class DDPSpawnStrategy(ParallelStrategy):
            start_method=start_method,
        )

    def on_exception(self, exception: BaseException) -> None:
        _augment_message(
            exception,
            pattern=".*Expected to have finished reduction in the prior iteration.*",
            new_message=(
                "It looks like your LightningModule has parameters that were not used in producing the loss returned"
                " by training_step. If this is intentional, you must enable the detection of unused parameters in DDP,"
                f" either by setting the string value `strategy='ddp_{self._start_method}_find_unused_parameters_true'`"
                " or by setting the flag in the strategy with `strategy=DDPSpawnStrategy(find_unused_parameters=True)`."
            ),
        )

    def teardown(self) -> None:
        log.detail(f"{self.__class__.__name__}: tearing down strategy")
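Note: unlike `DDPStrategy.on_exception`, the spawn variant interpolates `self._start_method` into the suggested alias, so a run launched with the `fork` start method is pointed at `strategy='ddp_fork_find_unused_parameters_true'` rather than the `spawn` spelling.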
@@ -528,6 +528,10 @@ class Strategy(ABC):
        """Called in the training loop before anything happens for that batch."""
        pass

    def on_exception(self, exception: BaseException) -> None:
        """Called when the trainer execution is interrupted by an exception."""
        pass

    def __getstate__(self) -> Dict:
        # `LightningOptimizer` overrides `self.__class__` so they cannot be pickled
        state = dict(vars(self))  # copy
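Note: the base implementation is a deliberate no-op, so strategies that don't care about exceptions need no changes. A custom strategy can opt in by overriding the hook; a hypothetical sketch (the class name and body are illustrative, not part of this commit):

```python
from lightning.pytorch.strategies import SingleDeviceStrategy


class CleanupStrategy(SingleDeviceStrategy):
    def on_exception(self, exception: BaseException) -> None:
        # Hypothetical cleanup: runs after the callback `on_exception` hooks,
        # before loggers are finalized and the Trainer tears down.
        print(f"{type(exception).__name__} reached the strategy: {exception}")
```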
@@ -48,11 +48,13 @@ def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *arg
        if not trainer.interrupted:
            trainer.state.status = TrainerStatus.INTERRUPTED
            trainer._call_callback_hooks("on_exception", exception)
            trainer.strategy.on_exception(exception)
            for logger in trainer.loggers:
                logger.finalize("failed")
    except BaseException as exception:
        trainer.state.status = TrainerStatus.INTERRUPTED
        trainer._call_callback_hooks("on_exception", exception)
        trainer.strategy.on_exception(exception)
        for logger in trainer.loggers:
            logger.finalize("failed")
        trainer._teardown()
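Note the ordering in both branches: callback `on_exception` hooks fire first, then the new strategy hook, then every logger is finalized with status "failed". In the generic `BaseException` branch, `trainer._teardown()` additionally runs before the exception propagates out of the function (the re-raise sits outside this hunk).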
@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re

from lightning.fabric.utilities.exceptions import MisconfigurationException  # noqa: F401
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_11_0


class SIGTERMException(SystemExit):
@@ -27,3 +30,13 @@ class SIGTERMException(SystemExit):

class _TunerExitException(Exception):
    """Exception used to exit early while tuning."""


def _augment_message(exception: BaseException, pattern: str, new_message: str) -> None:
    if _PYTHON_GREATER_EQUAL_3_11_0 and any(re.match(pattern, message, re.DOTALL) for message in exception.args):
        exception.add_note(new_message)
    else:
        # Remove this when Python 3.11 becomes the minimum supported version
        exception.args = tuple(
            new_message if re.match(pattern, message, re.DOTALL) else message for message in exception.args
        )
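Note: `_augment_message` is a private helper with two code paths. On Python 3.11+ it leaves `exception.args` untouched and attaches the hint via `BaseException.add_note`, which the interpreter prints beneath the traceback; on older versions it rewrites the matching arg in place. A small usage sketch (the hint text here is illustrative):

```python
from lightning.pytorch.utilities.exceptions import _augment_message

exc = RuntimeError("Expected to have finished reduction in the prior iteration before starting a new one.")
_augment_message(
    exc,
    pattern=".*Expected to have finished reduction in the prior iteration.*",
    new_message="Hint: enable find_unused_parameters in the DDP strategy.",
)
# Python 3.11+: exc.args unchanged, exc.__notes__ == ["Hint: ..."]
# Older:        exc.args == ("Hint: enable find_unused_parameters in the DDP strategy.",)
```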
@@ -20,6 +20,7 @@ from lightning_utilities.core.imports import compare_version, package_available,

_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
_PYTHON_GREATER_EQUAL_3_11_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 11)
# duplicated from fabric because HPU is patching it below
_TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0")
_TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1")
@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from torch.multiprocessing import ProcessRaisedException

import tests_pytorch.helpers.pipelines as tpipes
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.demos.boring_classes import BoringModel
@@ -18,6 +21,7 @@ from lightning.pytorch.trainer import Trainer
from tests_pytorch.helpers.datamodules import ClassifDataModule
from tests_pytorch.helpers.runif import RunIf
from tests_pytorch.helpers.simple_models import ClassificationModel
from tests_pytorch.strategies.test_ddp_strategy import UnusedParametersModel


@RunIf(min_cuda_gpus=2, sklearn=True)
@@ -73,3 +77,13 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
    )
    trainer.fit(model, train_dataloaders=model.train_dataloader(), val_dataloaders=model.val_dataloader())
    assert trainer.state.finished, "DDP doesn't work with dataloaders passed to fit()."


def test_ddp_spawn_find_unused_parameters_exception():
    """Test that the DDP strategy can change PyTorch's error message so that it's more useful for Lightning
    users."""
    trainer = Trainer(accelerator="cpu", devices=1, strategy="ddp_spawn", max_steps=2)
    with pytest.raises(
        ProcessRaisedException, match="It looks like your LightningModule has parameters that were not used in"
    ):
        trainer.fit(UnusedParametersModel())
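Note the expected exception type: under `ddp_spawn` the `RuntimeError` is raised in a worker process, and `torch.multiprocessing` re-raises it in the parent as a `ProcessRaisedException` whose message embeds the worker traceback, so the test matches the augmented text on that wrapper instead of on `RuntimeError` directly.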
@@ -276,3 +276,22 @@ def test_ddp_strategy_checkpoint_zero_redundancy_optimizer(tmpdir, strategy):
    # Assert model parameters are identical after loading
    for trained_param, loaded_param in zip(model.parameters(), saved_model.parameters()):
        assert torch.equal(trained_param.to("cpu"), loaded_param)


class UnusedParametersModel(BoringModel):
    def __init__(self):
        super().__init__()
        self.intermediate_layer = torch.nn.Linear(32, 32)

    def training_step(self, batch, batch_idx):
        with torch.no_grad():
            batch = self.intermediate_layer(batch)
        return super().training_step(batch, batch_idx)


def test_ddp_strategy_find_unused_parameters_exception():
    """Test that the DDP strategy can change PyTorch's error message so that it's more useful for Lightning
    users."""
    trainer = Trainer(accelerator="cpu", devices=1, strategy="ddp", max_steps=2)
    with pytest.raises(RuntimeError, match="It looks like your LightningModule has parameters that were not used in"):
        trainer.fit(UnusedParametersModel())
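Note: the `torch.no_grad()` block is what makes the parameters "unused": it detaches `intermediate_layer` from the autograd graph, so its weight and bias never receive gradients and DDP's reducer raises the "Expected to have finished reduction" error that `on_exception` rewrites. An illustrative counterpart without the guard (not part of this commit) would train cleanly:

```python
import torch
from lightning.pytorch.demos.boring_classes import BoringModel


class AllParametersUsedModel(BoringModel):
    def __init__(self):
        super().__init__()
        self.intermediate_layer = torch.nn.Linear(32, 32)

    def training_step(self, batch, batch_idx):
        batch = self.intermediate_layer(batch)  # differentiable: gradients reach every parameter
        return super().training_step(batch, batch_idx)
```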
@@ -17,9 +17,10 @@ import math
import os
import pickle
from argparse import Namespace
from contextlib import nullcontext
from contextlib import nullcontext, suppress
from copy import deepcopy
from pathlib import Path
from unittest import mock
from unittest.mock import ANY, call, Mock, patch

import cloudpickle
@@ -2123,3 +2124,19 @@ def test_trainer_compiled_model(tmp_path, monkeypatch):
    trainer = Trainer(**trainer_kwargs)
    with pytest.raises(TypeError, match="must be a `Light"):
        trainer.fit(object())


@pytest.mark.parametrize("exception_type", [KeyboardInterrupt, RuntimeError])
def test_trainer_calls_strategy_on_exception(exception_type):
    """Test that when an exception occurs, the Trainer lets the strategy process it."""
    exception = exception_type("Test exception")

    class ExceptionModel(BoringModel):
        def on_fit_start(self):
            raise exception

    trainer = Trainer()
    with mock.patch("lightning.pytorch.strategies.strategy.Strategy.on_exception") as on_exception_mock:
        with suppress(Exception):
            trainer.fit(ExceptionModel())
    on_exception_mock.assert_called_once_with(exception)
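A subtlety in this test: `suppress(Exception)` does not catch `KeyboardInterrupt`, which subclasses `BaseException` directly. The `KeyboardInterrupt` parametrization still passes because `_call_and_handle_interrupt` absorbs the interrupt instead of re-raising it; only the `RuntimeError` case relies on the suppression.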
@@ -0,0 +1,55 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lightning.pytorch.utilities.exceptions import _augment_message
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_11_0


def test_augment_message():
    # exception without args
    exception = Exception()
    _augment_message(exception, "", "new message")
    assert not exception.args
    if _PYTHON_GREATER_EQUAL_3_11_0:
        assert not exception.__notes__

    # exception with one arg
    exception = Exception("Test message.")
    _augment_message(exception, "Test", "New Test message")
    if _PYTHON_GREATER_EQUAL_3_11_0:
        assert exception.__notes__ == ["New Test message"]
        assert exception.args == ("Test message.",)
    else:
        assert exception.args == ("New Test message",)

    # pattern matching
    exception = Exception("Hello. Test message. Over!")
    _augment_message(exception, ".*Test.*Over.*", "New Test message")
    if _PYTHON_GREATER_EQUAL_3_11_0:
        assert exception.__notes__ == ["New Test message"]
        assert exception.args == ("Hello. Test message. Over!",)
    else:
        assert exception.args == ("New Test message",)

    # exception with multiple args
    exception = Exception("Message 1", "Message 2", "Message 3")
    _augment_message(exception, "Message 2", "New message 2")
    if _PYTHON_GREATER_EQUAL_3_11_0:
        assert exception.__notes__ == ["New message 2"]
        assert exception.args == (
            "Message 1",
            "Message 2",
            "Message 3",
        )
    else:
        assert exception.args == ("Message 1", "New message 2", "Message 3")
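Taken together, these assertions pin down the contract of `_augment_message`: on Python 3.11+ the original `args` are always preserved and the replacement lands in `__notes__`, while on older interpreters only the arg matching the pattern is swapped out and the rest are left untouched.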