Add `Strategy.on_exception` (#16646)

Adrian Wälchli 2023-02-08 15:00:31 +01:00 committed by GitHub
parent 1288e4ccc4
commit 74ee699dfd
11 changed files with 155 additions and 1 deletion

View File

@@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added suffix option to DDP strategy names to enable `find_unused_parameters=True`, for example `strategy="ddp_find_unused_parameters_true"` ([#16611](https://github.com/Lightning-AI/lightning/pull/16611))
- Added a new method `Strategy.on_exception` to the strategy base interface ([#16646](https://github.com/Lightning-AI/lightning/pull/16646))

### Changed

- "Native" suffix removal ([#16490](https://github.com/Lightning-AI/lightning/pull/16490))

View File

@@ -45,6 +45,7 @@ from lightning.pytorch.strategies.parallel import ParallelStrategy
from lightning.pytorch.strategies.strategy import TBroadcast
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.distributed import register_ddp_comm_hook
from lightning.pytorch.utilities.exceptions import _augment_message
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only
from lightning.pytorch.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep
@@ -364,6 +365,18 @@ class DDPStrategy(ParallelStrategy):
            description=f"{cls.__class__.__name__}",
        )

    def on_exception(self, exception: BaseException) -> None:
        _augment_message(
            exception,
            pattern=".*Expected to have finished reduction in the prior iteration.*",
            new_message=(
                "It looks like your LightningModule has parameters that were not used in producing the loss returned"
                " by training_step. If this is intentional, you must enable the detection of unused parameters in DDP,"
                " either by setting the string value `strategy='ddp_find_unused_parameters_true'`"
                " or by setting the flag in the strategy with `strategy=DDPStrategy(find_unused_parameters=True)`."
            ),
        )

    def teardown(self) -> None:
        log.detail(f"{self.__class__.__name__}: tearing down strategy")
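The augmented message points at two equivalent remedies. A minimal usage sketch (editor's illustration, not part of this commit; the registered strategy name comes from #16611, and the device counts are arbitrary):

from lightning.pytorch import Trainer
from lightning.pytorch.strategies import DDPStrategy

# Option 1: use the registered strategy name with the suffix added in #16611
trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp_find_unused_parameters_true")

# Option 2: pass the flag on the strategy object directly
trainer = Trainer(accelerator="gpu", devices=2, strategy=DDPStrategy(find_unused_parameters=True))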

View File

@@ -42,6 +42,7 @@ from lightning.pytorch.strategies.parallel import ParallelStrategy
from lightning.pytorch.strategies.strategy import TBroadcast
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.distributed import register_ddp_comm_hook
from lightning.pytorch.utilities.exceptions import _augment_message
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only
from lightning.pytorch.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep
@@ -330,6 +331,18 @@ class DDPSpawnStrategy(ParallelStrategy):
            start_method=start_method,
        )

    def on_exception(self, exception: BaseException) -> None:
        _augment_message(
            exception,
            pattern=".*Expected to have finished reduction in the prior iteration.*",
            new_message=(
                "It looks like your LightningModule has parameters that were not used in producing the loss returned"
                " by training_step. If this is intentional, you must enable the detection of unused parameters in DDP,"
                f" either by setting the string value `strategy='ddp_{self._start_method}_find_unused_parameters_true'`"
                " or by setting the flag in the strategy with `strategy=DDPSpawnStrategy(find_unused_parameters=True)`."
            ),
        )

    def teardown(self) -> None:
        log.detail(f"{self.__class__.__name__}: tearing down strategy")
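Note that the f-string interpolates the start method into the suggested name, so with the default "spawn" method the message recommends, for example:

trainer = Trainer(accelerator="cpu", devices=2, strategy="ddp_spawn_find_unused_parameters_true")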

View File

@@ -528,6 +528,10 @@ class Strategy(ABC):
        """Called in the training loop before anything happens for that batch."""
        pass

    def on_exception(self, exception: BaseException) -> None:
        """Called when the trainer execution is interrupted by an exception."""
        pass

    def __getstate__(self) -> Dict:
        # `LightningOptimizer` overrides `self.__class__` so they cannot be pickled
        state = dict(vars(self))  # copy
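As a hedged illustration of the new hook (the MyDebugStrategy subclass below is hypothetical, not part of this commit), a custom strategy can override on_exception to inspect or annotate a failure before the trainer finalizes its loggers and tears down:

from lightning.pytorch import Trainer
from lightning.pytorch.strategies import SingleDeviceStrategy

class MyDebugStrategy(SingleDeviceStrategy):
    def on_exception(self, exception: BaseException) -> None:
        # Hypothetical example: record some context about the failure
        print(f"Run interrupted by {type(exception).__name__}: {exception}")
        super().on_exception(exception)

trainer = Trainer(accelerator="cpu", strategy=MyDebugStrategy("cpu"))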

View File

@@ -48,11 +48,13 @@ def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *arg
        if not trainer.interrupted:
            trainer.state.status = TrainerStatus.INTERRUPTED
            trainer._call_callback_hooks("on_exception", exception)
            trainer.strategy.on_exception(exception)
            for logger in trainer.loggers:
                logger.finalize("failed")
    except BaseException as exception:
        trainer.state.status = TrainerStatus.INTERRUPTED
        trainer._call_callback_hooks("on_exception", exception)
        trainer.strategy.on_exception(exception)
        for logger in trainer.loggers:
            logger.finalize("failed")
        trainer._teardown()
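For reference, the callback hook invoked just before the new strategy hook is public API. A small sketch (ExceptionLogger and FailingModel are illustrative names, not from this commit) showing that both hooks receive the same exception object, in the order wired above:

from contextlib import suppress

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.demos.boring_classes import BoringModel

class ExceptionLogger(Callback):
    def on_exception(self, trainer, pl_module, exception):
        # Runs first, per the order in _call_and_handle_interrupt above
        print(f"callback saw: {exception!r}")

class FailingModel(BoringModel):
    def on_fit_start(self):
        raise RuntimeError("boom")

trainer = Trainer(callbacks=[ExceptionLogger()], max_steps=1)
with suppress(RuntimeError):
    trainer.fit(FailingModel())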

View File

@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re

from lightning.fabric.utilities.exceptions import MisconfigurationException  # noqa: F401
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_11_0


class SIGTERMException(SystemExit):
@@ -27,3 +30,13 @@ class SIGTERMException(SystemExit):
class _TunerExitException(Exception):
    """Exception used to exit early while tuning."""


def _augment_message(exception: BaseException, pattern: str, new_message: str) -> None:
    if _PYTHON_GREATER_EQUAL_3_11_0 and any(re.match(pattern, message, re.DOTALL) for message in exception.args):
        exception.add_note(new_message)
    else:
        # Remove this when Python 3.11 becomes the minimum supported version
        exception.args = tuple(
            new_message if re.match(pattern, message, re.DOTALL) else message for message in exception.args
        )
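In short, on Python >= 3.11 the helper attaches the hint as a PEP 678 exception note and leaves args untouched; on older versions it rewrites the matching args entry in place. A quick sketch of the two behaviors (the message strings are illustrative):

import sys

from lightning.pytorch.utilities.exceptions import _augment_message

exc = RuntimeError("Expected to have finished reduction in the prior iteration before starting a new one.")
_augment_message(exc, pattern=".*Expected to have finished reduction.*", new_message="Hint for Lightning users")
if sys.version_info >= (3, 11):
    assert exc.__notes__ == ["Hint for Lightning users"]  # original message preserved in exc.args
else:
    assert exc.args == ("Hint for Lightning users",)  # message replaced in place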

View File

@@ -20,6 +20,7 @@ from lightning_utilities.core.imports import compare_version, package_available,
_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
_PYTHON_GREATER_EQUAL_3_11_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 11)

# duplicated from fabric because HPU is patching it below
_TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0")
_TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1")

View File

@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from torch.multiprocessing import ProcessRaisedException

import tests_pytorch.helpers.pipelines as tpipes
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.demos.boring_classes import BoringModel
@@ -18,6 +21,7 @@ from lightning.pytorch.trainer import Trainer
from tests_pytorch.helpers.datamodules import ClassifDataModule
from tests_pytorch.helpers.runif import RunIf
from tests_pytorch.helpers.simple_models import ClassificationModel
from tests_pytorch.strategies.test_ddp_strategy import UnusedParametersModel


@RunIf(min_cuda_gpus=2, sklearn=True)
@@ -73,3 +77,13 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
    )
    trainer.fit(model, train_dataloaders=model.train_dataloader(), val_dataloaders=model.val_dataloader())
    assert trainer.state.finished, "DDP doesn't work with dataloaders passed to fit()."


def test_ddp_spawn_find_unused_parameters_exception():
    """Test that the DDP strategy can change PyTorch's error message so that it's more useful for Lightning
    users."""
    trainer = Trainer(accelerator="cpu", devices=1, strategy="ddp_spawn", max_steps=2)
    with pytest.raises(
        ProcessRaisedException, match="It looks like your LightningModule has parameters that were not used in"
    ):
        trainer.fit(UnusedParametersModel())

View File

@@ -276,3 +276,22 @@ def test_ddp_strategy_checkpoint_zero_redundancy_optimizer(tmpdir, strategy):
    # Assert model parameters are identical after loading
    for trained_param, loaded_param in zip(model.parameters(), saved_model.parameters()):
        assert torch.equal(trained_param.to("cpu"), loaded_param)


class UnusedParametersModel(BoringModel):
    def __init__(self):
        super().__init__()
        self.intermediate_layer = torch.nn.Linear(32, 32)

    def training_step(self, batch, batch_idx):
        with torch.no_grad():
            batch = self.intermediate_layer(batch)
        return super().training_step(batch, batch_idx)


def test_ddp_strategy_find_unused_parameters_exception():
    """Test that the DDP strategy can change PyTorch's error message so that it's more useful for Lightning
    users."""
    trainer = Trainer(accelerator="cpu", devices=1, strategy="ddp", max_steps=2)
    with pytest.raises(RuntimeError, match="It looks like your LightningModule has parameters that were not used in"):
        trainer.fit(UnusedParametersModel())

View File

@@ -17,9 +17,10 @@ import math
import os
import pickle
from argparse import Namespace
from contextlib import nullcontext, suppress
from copy import deepcopy
from pathlib import Path
from unittest import mock
from unittest.mock import ANY, call, Mock, patch

import cloudpickle
@@ -2123,3 +2124,19 @@ def test_trainer_compiled_model(tmp_path, monkeypatch):
    trainer = Trainer(**trainer_kwargs)
    with pytest.raises(TypeError, match="must be a `Light"):
        trainer.fit(object())


@pytest.mark.parametrize("exception_type", [KeyboardInterrupt, RuntimeError])
def test_trainer_calls_strategy_on_exception(exception_type):
    """Test that when an exception occurs, the Trainer lets the strategy process it."""
    exception = exception_type("Test exception")

    class ExceptionModel(BoringModel):
        def on_fit_start(self):
            raise exception

    trainer = Trainer()
    with mock.patch("lightning.pytorch.strategies.strategy.Strategy.on_exception") as on_exception_mock:
        with suppress(Exception):
            trainer.fit(ExceptionModel())
    on_exception_mock.assert_called_once_with(exception)

View File

@@ -0,0 +1,55 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lightning.pytorch.utilities.exceptions import _augment_message
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_11_0


def test_augment_message():
    # exception without args
    exception = Exception()
    _augment_message(exception, "", "new message")
    assert not exception.args
    if _PYTHON_GREATER_EQUAL_3_11_0:
        # `__notes__` only exists once `add_note()` has been called, so guard the access
        assert not getattr(exception, "__notes__", [])

    # exception with one arg
    exception = Exception("Test message.")
    _augment_message(exception, "Test", "New Test message")
    if _PYTHON_GREATER_EQUAL_3_11_0:
        assert exception.__notes__ == ["New Test message"]
        assert exception.args == ("Test message.",)
    else:
        assert exception.args == ("New Test message",)

    # pattern matching
    exception = Exception("Hello. Test message. Over!")
    _augment_message(exception, ".*Test.*Over.*", "New Test message")
    if _PYTHON_GREATER_EQUAL_3_11_0:
        assert exception.__notes__ == ["New Test message"]
        assert exception.args == ("Hello. Test message. Over!",)
    else:
        assert exception.args == ("New Test message",)

    # exception with multiple args
    exception = Exception("Message 1", "Message 2", "Message 3")
    _augment_message(exception, "Message 2", "New message 2")
    if _PYTHON_GREATER_EQUAL_3_11_0:
        assert exception.__notes__ == ["New message 2"]
        assert exception.args == ("Message 1", "Message 2", "Message 3")
    else:
        assert exception.args == ("Message 1", "New message 2", "Message 3")