From 74ee699dfdb3e2a1ffe44ab0058336e1e2e4420f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 8 Feb 2023 15:00:31 +0100 Subject: [PATCH] Add `Strategy.on_exception` (#16646) --- src/lightning/pytorch/CHANGELOG.md | 3 + src/lightning/pytorch/strategies/ddp.py | 13 +++++ src/lightning/pytorch/strategies/ddp_spawn.py | 13 +++++ src/lightning/pytorch/strategies/strategy.py | 4 ++ src/lightning/pytorch/trainer/call.py | 2 + src/lightning/pytorch/utilities/exceptions.py | 13 +++++ src/lightning/pytorch/utilities/imports.py | 1 + .../strategies/test_ddp_spawn.py | 14 +++++ .../strategies/test_ddp_strategy.py | 19 +++++++ tests/tests_pytorch/trainer/test_trainer.py | 19 ++++++- .../utilities/test_exceptions.py | 55 +++++++++++++++++++ 11 files changed, 155 insertions(+), 1 deletion(-) create mode 100644 tests/tests_pytorch/utilities/test_exceptions.py diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index fd7735c149..21ee565f07 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added suffix option to DDP strategy names to enable `find_unused_parameters=True`, for example `strategy="ddp_find_unused_parameters_true"` ([#16611](https://github.com/Lightning-AI/lightning/pull/16611)) +- Added a new method `Strategy.on_exception` to the strategy base interface ([#16646](https://github.com/Lightning-AI/lightning/pull/16646)) + + ### Changed - "Native" suffix removal ([#16490](https://github.com/Lightning-AI/lightning/pull/16490)) diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py index e3204ba6c8..fc694a4046 100644 --- a/src/lightning/pytorch/strategies/ddp.py +++ b/src/lightning/pytorch/strategies/ddp.py @@ -45,6 +45,7 @@ from lightning.pytorch.strategies.parallel import ParallelStrategy from lightning.pytorch.strategies.strategy import TBroadcast from lightning.pytorch.trainer.states import TrainerFn from lightning.pytorch.utilities.distributed import register_ddp_comm_hook +from lightning.pytorch.utilities.exceptions import _augment_message from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only from lightning.pytorch.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep @@ -364,6 +365,18 @@ class DDPStrategy(ParallelStrategy): description=f"{cls.__class__.__name__}", ) + def on_exception(self, exception: BaseException) -> None: + _augment_message( + exception, + pattern=".*Expected to have finished reduction in the prior iteration.*", + new_message=( + "It looks like your LightningModule has parameters that were not used in producing the loss returned" + " by training_step. If this is intentional, you must enable the detection of unused parameters in DDP," + " either by setting the string value `strategy='ddp_find_unused_parameters_true'`" + " or by setting the flag in the strategy with `strategy=DDPStrategy(find_unused_parameters=True)`." 
+ ), + ) + def teardown(self) -> None: log.detail(f"{self.__class__.__name__}: tearing down strategy") diff --git a/src/lightning/pytorch/strategies/ddp_spawn.py b/src/lightning/pytorch/strategies/ddp_spawn.py index ba08d7ad1f..d620ae8bf2 100644 --- a/src/lightning/pytorch/strategies/ddp_spawn.py +++ b/src/lightning/pytorch/strategies/ddp_spawn.py @@ -42,6 +42,7 @@ from lightning.pytorch.strategies.parallel import ParallelStrategy from lightning.pytorch.strategies.strategy import TBroadcast from lightning.pytorch.trainer.states import TrainerFn from lightning.pytorch.utilities.distributed import register_ddp_comm_hook +from lightning.pytorch.utilities.exceptions import _augment_message from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only from lightning.pytorch.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep @@ -330,6 +331,18 @@ class DDPSpawnStrategy(ParallelStrategy): start_method=start_method, ) + def on_exception(self, exception: BaseException) -> None: + _augment_message( + exception, + pattern=".*Expected to have finished reduction in the prior iteration.*", + new_message=( + "It looks like your LightningModule has parameters that were not used in producing the loss returned" + " by training_step. If this is intentional, you must enable the detection of unused parameters in DDP," + f" either by setting the string value `strategy='ddp_{self._start_method}_find_unused_parameters_true'`" + " or by setting the flag in the strategy with `strategy=DDPSpawnStrategy(find_unused_parameters=True)`." + ), + ) + def teardown(self) -> None: log.detail(f"{self.__class__.__name__}: tearing down strategy") diff --git a/src/lightning/pytorch/strategies/strategy.py b/src/lightning/pytorch/strategies/strategy.py index 6389fd940f..1469c701ce 100644 --- a/src/lightning/pytorch/strategies/strategy.py +++ b/src/lightning/pytorch/strategies/strategy.py @@ -528,6 +528,10 @@ class Strategy(ABC): """Called in the training loop before anything happens for that batch.""" pass + def on_exception(self, exception: BaseException) -> None: + """Called when the trainer execution is interrupted by an exception.""" + pass + def __getstate__(self) -> Dict: # `LightningOptimizer` overrides `self.__class__` so they cannot be pickled state = dict(vars(self)) # copy diff --git a/src/lightning/pytorch/trainer/call.py b/src/lightning/pytorch/trainer/call.py index d7589b77cd..03915b1b6a 100644 --- a/src/lightning/pytorch/trainer/call.py +++ b/src/lightning/pytorch/trainer/call.py @@ -48,11 +48,13 @@ def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *arg if not trainer.interrupted: trainer.state.status = TrainerStatus.INTERRUPTED trainer._call_callback_hooks("on_exception", exception) + trainer.strategy.on_exception(exception) for logger in trainer.loggers: logger.finalize("failed") except BaseException as exception: trainer.state.status = TrainerStatus.INTERRUPTED trainer._call_callback_hooks("on_exception", exception) + trainer.strategy.on_exception(exception) for logger in trainer.loggers: logger.finalize("failed") trainer._teardown() diff --git a/src/lightning/pytorch/utilities/exceptions.py b/src/lightning/pytorch/utilities/exceptions.py index a4d694ceb7..e9dc268074 100644 --- a/src/lightning/pytorch/utilities/exceptions.py +++ b/src/lightning/pytorch/utilities/exceptions.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import re
+
 from lightning.fabric.utilities.exceptions import MisconfigurationException  # noqa: F401
+from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_11_0
 
 
 class SIGTERMException(SystemExit):
@@ -27,3 +30,13 @@ class SIGTERMException(SystemExit):
 
 class _TunerExitException(Exception):
     """Exception used to exit early while tuning."""
+
+
+def _augment_message(exception: BaseException, pattern: str, new_message: str) -> None:
+    if _PYTHON_GREATER_EQUAL_3_11_0 and any(isinstance(m, str) and re.match(pattern, m, re.DOTALL) for m in exception.args):
+        exception.add_note(new_message)
+    else:
+        # Remove this when Python 3.11 becomes the minimum supported version
+        exception.args = tuple(
+            new_message if isinstance(m, str) and re.match(pattern, m, re.DOTALL) else m for m in exception.args
+        )
diff --git a/src/lightning/pytorch/utilities/imports.py b/src/lightning/pytorch/utilities/imports.py
index 24491db545..ae10fc88b2 100644
--- a/src/lightning/pytorch/utilities/imports.py
+++ b/src/lightning/pytorch/utilities/imports.py
@@ -20,6 +20,7 @@ from lightning_utilities.core.imports import compare_version, package_available,
 
 _PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
 _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
+_PYTHON_GREATER_EQUAL_3_11_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 11)
 # duplicated from fabric because HPU is patching it below
 _TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0")
 _TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1")
diff --git a/tests/tests_pytorch/strategies/test_ddp_spawn.py b/tests/tests_pytorch/strategies/test_ddp_spawn.py
index 205fdc7569..9a015067a8 100644
--- a/tests/tests_pytorch/strategies/test_ddp_spawn.py
+++ b/tests/tests_pytorch/strategies/test_ddp_spawn.py
@@ -11,6 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import pytest
+from torch.multiprocessing import ProcessRaisedException
+
 import tests_pytorch.helpers.pipelines as tpipes
 from lightning.pytorch.callbacks import EarlyStopping
 from lightning.pytorch.demos.boring_classes import BoringModel
@@ -18,6 +21,7 @@ from lightning.pytorch.trainer import Trainer
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.runif import RunIf
 from tests_pytorch.helpers.simple_models import ClassificationModel
+from tests_pytorch.strategies.test_ddp_strategy import UnusedParametersModel
 
 
 @RunIf(min_cuda_gpus=2, sklearn=True)
@@ -73,3 +77,13 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
     )
     trainer.fit(model, train_dataloaders=model.train_dataloader(), val_dataloaders=model.val_dataloader())
     assert trainer.state.finished, "DDP doesn't work with dataloaders passed to fit()."
+
+
+def test_ddp_spawn_find_unused_parameters_exception():
+    """Test that the DDP spawn strategy can change PyTorch's error message so that it's more useful for Lightning
+    users."""
+    trainer = Trainer(accelerator="cpu", devices=1, strategy="ddp_spawn", max_steps=2)
+    with pytest.raises(
+        ProcessRaisedException, match="It looks like your LightningModule has parameters that were not used in"
+    ):
+        trainer.fit(UnusedParametersModel())
diff --git a/tests/tests_pytorch/strategies/test_ddp_strategy.py b/tests/tests_pytorch/strategies/test_ddp_strategy.py
index d57bb5feca..45dc56621c 100644
--- a/tests/tests_pytorch/strategies/test_ddp_strategy.py
+++ b/tests/tests_pytorch/strategies/test_ddp_strategy.py
@@ -276,3 +276,22 @@ def test_ddp_strategy_checkpoint_zero_redundancy_optimizer(tmpdir, strategy):
     # Assert model parameters are identical after loading
     for trained_param, loaded_param in zip(model.parameters(), saved_model.parameters()):
         assert torch.equal(trained_param.to("cpu"), loaded_param)
+
+
+class UnusedParametersModel(BoringModel):
+    def __init__(self):
+        super().__init__()
+        self.intermediate_layer = torch.nn.Linear(32, 32)
+
+    def training_step(self, batch, batch_idx):
+        with torch.no_grad():
+            batch = self.intermediate_layer(batch)
+        return super().training_step(batch, batch_idx)
+
+
+def test_ddp_strategy_find_unused_parameters_exception():
+    """Test that the DDP strategy can change PyTorch's error message so that it's more useful for Lightning
+    users."""
+    trainer = Trainer(accelerator="cpu", devices=1, strategy="ddp", max_steps=2)
+    with pytest.raises(RuntimeError, match="It looks like your LightningModule has parameters that were not used in"):
+        trainer.fit(UnusedParametersModel())
diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py
index f5a0c5040f..ce80d0bb18 100644
--- a/tests/tests_pytorch/trainer/test_trainer.py
+++ b/tests/tests_pytorch/trainer/test_trainer.py
@@ -17,9 +17,10 @@ import math
 import os
 import pickle
 from argparse import Namespace
-from contextlib import nullcontext
+from contextlib import nullcontext, suppress
 from copy import deepcopy
 from pathlib import Path
+from unittest import mock
 from unittest.mock import ANY, call, Mock, patch
 
 import cloudpickle
@@ -2123,3 +2124,19 @@ def test_trainer_compiled_model(tmp_path, monkeypatch):
     trainer = Trainer(**trainer_kwargs)
     with pytest.raises(TypeError, match="must be a `Light"):
         trainer.fit(object())
+
+
+@pytest.mark.parametrize("exception_type", [KeyboardInterrupt, RuntimeError])
+def test_trainer_calls_strategy_on_exception(exception_type):
+    """Test that when an exception occurs, the Trainer lets the strategy process it."""
+    exception = exception_type("Test exception")
+
+    class ExceptionModel(BoringModel):
+        def on_fit_start(self):
+            raise exception
+
+    trainer = Trainer()
+    with mock.patch("lightning.pytorch.strategies.strategy.Strategy.on_exception") as on_exception_mock:
+        with suppress(Exception):
+            trainer.fit(ExceptionModel())
+    on_exception_mock.assert_called_once_with(exception)
diff --git a/tests/tests_pytorch/utilities/test_exceptions.py b/tests/tests_pytorch/utilities/test_exceptions.py
new file mode 100644
index 0000000000..d58181d2a5
--- /dev/null
+++ b/tests/tests_pytorch/utilities/test_exceptions.py
@@ -0,0 +1,55 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from lightning.pytorch.utilities.exceptions import _augment_message
+from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_11_0
+
+
+def test_augment_message():
+    # exception without args
+    exception = Exception()
+    _augment_message(exception, "", "new message")
+    assert not exception.args
+    if _PYTHON_GREATER_EQUAL_3_11_0:
+        assert not hasattr(exception, "__notes__")
+
+    # exception with one arg
+    exception = Exception("Test message.")
+    _augment_message(exception, "Test", "New Test message")
+    if _PYTHON_GREATER_EQUAL_3_11_0:
+        assert exception.__notes__ == ["New Test message"]
+        assert exception.args == ("Test message.",)
+    else:
+        assert exception.args == ("New Test message",)
+
+    # pattern matching
+    exception = Exception("Hello. Test message. Over!")
+    _augment_message(exception, ".*Test.*Over.*", "New Test message")
+    if _PYTHON_GREATER_EQUAL_3_11_0:
+        assert exception.__notes__ == ["New Test message"]
+        assert exception.args == ("Hello. Test message. Over!",)
+    else:
+        assert exception.args == ("New Test message",)
+
+    # exception with multiple args
+    exception = Exception("Message 1", "Message 2", "Message 3")
+    _augment_message(exception, "Message 2", "New message 2")
+    if _PYTHON_GREATER_EQUAL_3_11_0:
+        assert exception.__notes__ == ["New message 2"]
+        assert exception.args == (
+            "Message 1",
+            "Message 2",
+            "Message 3",
+        )
+    else:
+        assert exception.args == ("Message 1", "New message 2", "Message 3")
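
A note on the version split inside `_augment_message`: on Python 3.11 and newer the original exception args are left untouched and the hint is attached with `BaseException.add_note`, so both texts show up in the traceback; on older interpreters the matching args entry is replaced outright. A minimal sketch of the observable difference (the message strings here are shortened for illustration, not the exact PyTorch error):

    import sys

    from lightning.pytorch.utilities.exceptions import _augment_message

    exc = RuntimeError("Expected to have finished reduction in the prior iteration.")
    _augment_message(exc, pattern=".*finished reduction.*", new_message="Hint: enable find_unused_parameters.")

    if sys.version_info >= (3, 11):
        # Original message preserved; the hint is printed as a note below the traceback.
        assert exc.args == ("Expected to have finished reduction in the prior iteration.",)
        assert exc.__notes__ == ["Hint: enable find_unused_parameters."]
    else:
        # The hint replaces the matching argument.
        assert exc.args == ("Hint: enable find_unused_parameters.",)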
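
Because `_call_and_handle_interrupt` now forwards every caught exception to `trainer.strategy.on_exception`, third-party strategies get a single hook for post-processing errors raised anywhere inside a Trainer entry point. A minimal sketch of such an override, assuming a hypothetical `MyClusterStrategy` and error text (`_augment_message` is a private helper and could change without notice):

    from lightning.pytorch.strategies import SingleDeviceStrategy
    from lightning.pytorch.utilities.exceptions import _augment_message


    class MyClusterStrategy(SingleDeviceStrategy):
        """Hypothetical strategy that rewrites a cryptic backend error into an actionable hint."""

        def on_exception(self, exception: BaseException) -> None:
            # Invoked by the Trainer for any exception raised during fit/validate/test/predict.
            _augment_message(
                exception,
                pattern=".*cryptic backend error.*",  # hypothetical error text to match
                new_message="The cluster backend failed. Check that all ranks reach the same collective calls.",
            )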