Make Trainer readable and debuggable (2/n) (#14862)

* clean trainer 2/n

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
William Falcon 2022-09-23 05:05:29 -07:00 committed by GitHub
parent c77d4a8394
commit fbfcc3d871
2 changed files with 69 additions and 46 deletions

pytorch_lightning/trainer/teardown.py (new file)

@@ -0,0 +1,57 @@
# Copyright Lightning AI.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import traceback
from typing import Any, Callable

from lightning_lite.utilities.distributed import distributed_available
from pytorch_lightning.trainer.states import TrainerStatus
from pytorch_lightning.utilities.rank_zero import rank_zero_warn


def call_and_handle_interrupt(trainer: Any, trainer_fn: Callable, *args: Any, **kwargs: Any) -> Any:
    r"""
    Error handling, intended to be used only for main trainer function entry points (fit, validate, test, predict)
    as all errors should funnel through them

    Args:
        trainer_fn: one of (fit, validate, test, predict)
        *args: positional arguments to be passed to the `trainer_fn`
        **kwargs: keyword arguments to be passed to `trainer_fn`
    """
    try:
        if trainer.strategy.launcher is not None:
            return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
        else:
            return trainer_fn(*args, **kwargs)
    # TODO: Unify both exceptions below, where `KeyboardError` doesn't re-raise
    except KeyboardInterrupt as exception:
        rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
        # user could press Ctrl+c many times... only shutdown once
        if not trainer.interrupted:
            trainer.state.status = TrainerStatus.INTERRUPTED
            trainer._call_callback_hooks("on_exception", exception)
            for logger in trainer.loggers:
                logger.finalize("failed")
    except BaseException as exception:
        trainer.state.status = TrainerStatus.INTERRUPTED
        if distributed_available() and trainer.world_size > 1:
            # try syncing remaining processes, kill otherwise
            trainer.strategy.reconciliate_processes(traceback.format_exc())
        trainer._call_callback_hooks("on_exception", exception)
        for logger in trainer.loggers:
            logger.finalize("failed")
        trainer._teardown()
        # teardown might access the stage so we reset it after
        trainer.state.stage = None
        raise
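
Not part of the commit, but useful for orientation: a minimal sketch of the failure path through the new free function, driving it with a hypothetical stub object that exposes only the attributes `call_and_handle_interrupt` actually touches. A `KeyboardInterrupt` raised inside the function would instead be absorbed after the `on_exception` hook fires (no re-raise), which is the asymmetry the TODO above points at.

from types import SimpleNamespace

from pytorch_lightning.trainer import teardown

# Hypothetical stub: just the attributes the handler reads and writes.
stub = SimpleNamespace(
    strategy=SimpleNamespace(launcher=None),  # no launcher -> trainer_fn is called directly
    state=SimpleNamespace(status=None, stage="train"),
    interrupted=False,
    loggers=[],
    world_size=1,
    _call_callback_hooks=lambda name, exc: print(f"hook fired: {name}"),
    _teardown=lambda: print("teardown ran"),
)

def boom() -> None:
    raise RuntimeError("simulated training failure")

try:
    teardown.call_and_handle_interrupt(stub, boom)
except RuntimeError:
    # status was set to INTERRUPTED, the stage was reset to None, then the error re-raised
    print(stub.state.status, stub.state.stage)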

pytorch_lightning/trainer/trainer.py

@@ -16,14 +16,13 @@ import inspect
 import logging
 import math
 import os
-import traceback
 import warnings
 from argparse import ArgumentParser, Namespace
 from contextlib import contextmanager
 from copy import deepcopy
 from datetime import timedelta
 from pathlib import Path
-from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Type, Union
+from typing import Any, Dict, Generator, Iterable, List, Optional, Type, Union
 from weakref import proxy
 
 import torch
@@ -38,7 +37,6 @@ from torch.utils.data import DataLoader
 import pytorch_lightning as pl
 from lightning_lite.utilities.cloud_io import get_filesystem
 from lightning_lite.utilities.data import _auto_add_worker_init_fn
-from lightning_lite.utilities.distributed import distributed_available
 from lightning_lite.utilities.types import _PATH
 from lightning_lite.utilities.warnings import PossibleUserWarning
 from pytorch_lightning.accelerators import (
@@ -75,6 +73,7 @@ from pytorch_lightning.profilers import (
     XLAProfiler,
 )
 from pytorch_lightning.strategies import ParallelStrategy, Strategy
+from pytorch_lightning.trainer import teardown
 from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations
 from pytorch_lightning.trainer.connectors.accelerator_connector import _LITERAL_WARN, AcceleratorConnector
 from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
@@ -620,43 +619,6 @@ class Trainer(
         self._last_train_dl_reload_epoch = None
         self._last_val_dl_reload_epoch: Optional[int] = None
 
-    def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: Any) -> Any:
-        r"""
-        Error handling, intended to be used only for main trainer function entry points (fit, validate, test, predict)
-        as all errors should funnel through them
-
-        Args:
-            trainer_fn: one of (fit, validate, test, predict)
-            *args: positional arguments to be passed to the `trainer_fn`
-            **kwargs: keyword arguments to be passed to `trainer_fn`
-        """
-        try:
-            if self.strategy.launcher is not None:
-                return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
-            else:
-                return trainer_fn(*args, **kwargs)
-        # TODO: Unify both exceptions below, where `KeyboardError` doesn't re-raise
-        except KeyboardInterrupt as exception:
-            rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
-            # user could press Ctrl+c many times... only shutdown once
-            if not self.interrupted:
-                self.state.status = TrainerStatus.INTERRUPTED
-                self._call_callback_hooks("on_exception", exception)
-                for logger in self.loggers:
-                    logger.finalize("failed")
-        except BaseException as exception:
-            self.state.status = TrainerStatus.INTERRUPTED
-            if distributed_available() and self.world_size > 1:
-                # try syncing remaining processes, kill otherwise
-                self.strategy.reconciliate_processes(traceback.format_exc())
-            self._call_callback_hooks("on_exception", exception)
-            for logger in self.loggers:
-                logger.finalize("failed")
-            self._teardown()
-            # teardown might access the stage so we reset it after
-            self.state.stage = None
-            raise
-
     def fit(
         self,
         model: "pl.LightningModule",
@@ -686,8 +648,8 @@ class Trainer(
         if not isinstance(model, pl.LightningModule):
             raise TypeError(f"`Trainer.fit()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
         self.strategy._lightning_module = model
-        self._call_and_handle_interrupt(
-            self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
+        teardown.call_and_handle_interrupt(
+            self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
         )
 
     def _fit_impl(
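
The shape of this refactor, reduced to a self-contained toy (hypothetical names, not the library's code): each public entry point forwards itself and its private implementation to one shared module-level handler, so there is a single error boundary to read and debug instead of one copy per method.

from typing import Any, Callable

def call_and_handle_interrupt(obj: Any, fn: Callable, *args: Any) -> Any:
    # single shared error boundary instead of one per entry point
    try:
        return fn(*args)
    except KeyboardInterrupt:
        print("graceful shutdown")

class Runner:
    def fit(self, data: Any) -> Any:
        # the instance is passed explicitly, mirroring the new free function
        return call_and_handle_interrupt(self, self._fit_impl, data)

    def _fit_impl(self, data: Any) -> Any:
        return f"fitted {len(data)} samples"

print(Runner().fit([1, 2, 3]))  # fitted 3 samples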
@@ -766,7 +728,9 @@ class Trainer(
         if model is not None and not isinstance(model, pl.LightningModule):
             raise TypeError(f"`Trainer.validate()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
         self.strategy._lightning_module = model or self.lightning_module
-        return self._call_and_handle_interrupt(self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule)
+        return teardown.call_and_handle_interrupt(
+            self, self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule
+        )
 
     def _validate_impl(
         self,
@@ -856,7 +820,9 @@ class Trainer(
         if model is not None and not isinstance(model, pl.LightningModule):
             raise TypeError(f"`Trainer.test()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
         self.strategy._lightning_module = model or self.lightning_module
-        return self._call_and_handle_interrupt(self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule)
+        return teardown.call_and_handle_interrupt(
+            self, self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule
+        )
 
     def _test_impl(
         self,
@@ -945,8 +911,8 @@ class Trainer(
         if model is not None and not isinstance(model, pl.LightningModule):
             raise TypeError(f"`Trainer.predict()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
         self.strategy._lightning_module = model or self.lightning_module
-        return self._call_and_handle_interrupt(
-            self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path
+        return teardown.call_and_handle_interrupt(
+            self, self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path
         )
 
     def _predict_impl(
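
From the user's side nothing changes. A hedged end-to-end sketch (the model and data here are invented for illustration): `Trainer.fit` is called exactly as before, and now routes through `teardown.call_and_handle_interrupt` internally.

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl

class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

data = DataLoader(TensorDataset(torch.randn(16, 4), torch.randn(16, 1)), batch_size=4)
trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
# errors and Ctrl+C inside fit are handled by the shared module-level function
trainer.fit(TinyModel(), data)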