From fbfcc3d871539340aa6d47f444c38726a7c536d8 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Fri, 23 Sep 2022 05:05:29 -0700
Subject: [PATCH] Make Trainer readable and debuggable (2/n) (#14862)

* clean trainer 2/n

* clean trainer 2/n

* clean trainer 2/n

* clean trainer 2/n

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 src/pytorch_lightning/trainer/teardown.py | 57 ++++++++++++++++++++++
 src/pytorch_lightning/trainer/trainer.py  | 58 +++++------------------
 2 files changed, 69 insertions(+), 46 deletions(-)
 create mode 100644 src/pytorch_lightning/trainer/teardown.py

diff --git a/src/pytorch_lightning/trainer/teardown.py b/src/pytorch_lightning/trainer/teardown.py
new file mode 100644
index 0000000000..b8f04ab899
--- /dev/null
+++ b/src/pytorch_lightning/trainer/teardown.py
@@ -0,0 +1,57 @@
+# Copyright Lightning AI.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import traceback
+from typing import Any, Callable
+
+from lightning_lite.utilities.distributed import distributed_available
+from pytorch_lightning.trainer.states import TrainerStatus
+from pytorch_lightning.utilities.rank_zero import rank_zero_warn
+
+
+def call_and_handle_interrupt(trainer: Any, trainer_fn: Callable, *args: Any, **kwargs: Any) -> Any:
+    r"""
+    Error handling, intended to be used only for main trainer function entry points (fit, validate, test, predict)
+    as all errors should funnel through them
+
+    Args:
+        trainer_fn: one of (fit, validate, test, predict)
+        *args: positional arguments to be passed to the `trainer_fn`
+        **kwargs: keyword arguments to be passed to `trainer_fn`
+    """
+    try:
+        if trainer.strategy.launcher is not None:
+            return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
+        else:
+            return trainer_fn(*args, **kwargs)
+    # TODO: Unify both exceptions below, where `KeyboardError` doesn't re-raise
+    except KeyboardInterrupt as exception:
+        rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
+        # user could press Ctrl+c many times... only shutdown once
+        if not trainer.interrupted:
+            trainer.state.status = TrainerStatus.INTERRUPTED
+            trainer._call_callback_hooks("on_exception", exception)
+            for logger in trainer.loggers:
+                logger.finalize("failed")
+    except BaseException as exception:
+        trainer.state.status = TrainerStatus.INTERRUPTED
+        if distributed_available() and trainer.world_size > 1:
+            # try syncing remaining processes, kill otherwise
+            trainer.strategy.reconciliate_processes(traceback.format_exc())
+        trainer._call_callback_hooks("on_exception", exception)
+        for logger in trainer.loggers:
+            logger.finalize("failed")
+        trainer._teardown()
+        # teardown might access the stage so we reset it after
+        trainer.state.stage = None
+        raise
diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py
index ce0ed55d34..c21fb226ec 100644
--- a/src/pytorch_lightning/trainer/trainer.py
+++ b/src/pytorch_lightning/trainer/trainer.py
@@ -16,14 +16,13 @@ import inspect
 import logging
 import math
 import os
-import traceback
 import warnings
 from argparse import ArgumentParser, Namespace
 from contextlib import contextmanager
 from copy import deepcopy
 from datetime import timedelta
 from pathlib import Path
-from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Type, Union
+from typing import Any, Dict, Generator, Iterable, List, Optional, Type, Union
 from weakref import proxy
 
 import torch
@@ -38,7 +37,6 @@ from torch.utils.data import DataLoader
 import pytorch_lightning as pl
 from lightning_lite.utilities.cloud_io import get_filesystem
 from lightning_lite.utilities.data import _auto_add_worker_init_fn
-from lightning_lite.utilities.distributed import distributed_available
 from lightning_lite.utilities.types import _PATH
 from lightning_lite.utilities.warnings import PossibleUserWarning
 from pytorch_lightning.accelerators import (
@@ -75,6 +73,7 @@ from pytorch_lightning.profilers import (
     XLAProfiler,
 )
 from pytorch_lightning.strategies import ParallelStrategy, Strategy
+from pytorch_lightning.trainer import teardown
 from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations
 from pytorch_lightning.trainer.connectors.accelerator_connector import _LITERAL_WARN, AcceleratorConnector
 from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
@@ -620,43 +619,6 @@ class Trainer(
         self._last_train_dl_reload_epoch = None
         self._last_val_dl_reload_epoch: Optional[int] = None
 
-    def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: Any) -> Any:
-        r"""
-        Error handling, intended to be used only for main trainer function entry points (fit, validate, test, predict)
-        as all errors should funnel through them
-
-        Args:
-            trainer_fn: one of (fit, validate, test, predict)
-            *args: positional arguments to be passed to the `trainer_fn`
-            **kwargs: keyword arguments to be passed to `trainer_fn`
-        """
-        try:
-            if self.strategy.launcher is not None:
-                return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
-            else:
-                return trainer_fn(*args, **kwargs)
-        # TODO: Unify both exceptions below, where `KeyboardError` doesn't re-raise
-        except KeyboardInterrupt as exception:
-            rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
-            # user could press Ctrl+c many times... only shutdown once
-            if not self.interrupted:
-                self.state.status = TrainerStatus.INTERRUPTED
-                self._call_callback_hooks("on_exception", exception)
-                for logger in self.loggers:
-                    logger.finalize("failed")
-        except BaseException as exception:
-            self.state.status = TrainerStatus.INTERRUPTED
-            if distributed_available() and self.world_size > 1:
-                # try syncing remaining processes, kill otherwise
-                self.strategy.reconciliate_processes(traceback.format_exc())
-            self._call_callback_hooks("on_exception", exception)
-            for logger in self.loggers:
-                logger.finalize("failed")
-            self._teardown()
-            # teardown might access the stage so we reset it after
-            self.state.stage = None
-            raise
-
     def fit(
         self,
         model: "pl.LightningModule",
@@ -686,8 +648,8 @@ class Trainer(
         if not isinstance(model, pl.LightningModule):
             raise TypeError(f"`Trainer.fit()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
         self.strategy._lightning_module = model
-        self._call_and_handle_interrupt(
-            self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
+        teardown.call_and_handle_interrupt(
+            self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
         )
 
     def _fit_impl(
         self,
@@ -766,7 +728,9 @@ class Trainer(
         if model is not None and not isinstance(model, pl.LightningModule):
             raise TypeError(f"`Trainer.validate()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
         self.strategy._lightning_module = model or self.lightning_module
-        return self._call_and_handle_interrupt(self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule)
+        return teardown.call_and_handle_interrupt(
+            self, self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule
+        )
 
     def _validate_impl(
         self,
@@ -856,7 +820,9 @@ class Trainer(
         if model is not None and not isinstance(model, pl.LightningModule):
             raise TypeError(f"`Trainer.test()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
         self.strategy._lightning_module = model or self.lightning_module
-        return self._call_and_handle_interrupt(self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule)
+        return teardown.call_and_handle_interrupt(
+            self, self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule
+        )
 
     def _test_impl(
         self,
@@ -945,8 +911,8 @@ class Trainer(
         if model is not None and not isinstance(model, pl.LightningModule):
             raise TypeError(f"`Trainer.predict()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
         self.strategy._lightning_module = model or self.lightning_module
-        return self._call_and_handle_interrupt(
-            self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path
+        return teardown.call_and_handle_interrupt(
+            self, self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path
         )
 
     def _predict_impl(
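
Usage note (illustration, not part of the patch): after this change the error-handling
wrapper is a free function in `pytorch_lightning.trainer.teardown` that receives the
`Trainer` as an explicit first argument, so `fit`/`validate`/`test`/`predict` delegate
to it instead of to a bound method. A minimal sketch of that delegation, mirroring the
`fit()` hunk above; `model` and `train_loader` are placeholders for any
`LightningModule` and `DataLoader`:

    from pytorch_lightning.trainer import teardown

    # What `trainer.fit(model, train_loader)` now does internally:
    trainer.strategy._lightning_module = model
    teardown.call_and_handle_interrupt(
        trainer,            # the Trainer instance, passed explicitly
        trainer._fit_impl,  # the entry-point implementation to run
        model, train_loader, None, None, None,  # *args forwarded to _fit_impl
    )

Per the function body above, a `KeyboardInterrupt` inside the wrapped call produces a
graceful shutdown without re-raising, while any other exception tears the `Trainer`
down and re-raises.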