Make Trainer readable and debuggable (2/n) (#14862)

* clean trainer 2/n

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
William Falcon 2022-09-23 05:05:29 -07:00 committed by GitHub
parent c77d4a8394
commit fbfcc3d871
2 changed files with 69 additions and 46 deletions

src/pytorch_lightning/trainer/teardown.py (new file)

@@ -0,0 +1,57 @@
# Copyright Lightning AI.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import traceback
from typing import Any, Callable

from lightning_lite.utilities.distributed import distributed_available
from pytorch_lightning.trainer.states import TrainerStatus
from pytorch_lightning.utilities.rank_zero import rank_zero_warn


def call_and_handle_interrupt(trainer: Any, trainer_fn: Callable, *args: Any, **kwargs: Any) -> Any:
    r"""
    Error handling, intended to be used only for the main trainer entry points (fit, validate, test,
    predict), as all errors should funnel through them.

    Args:
        trainer: the Trainer instance whose state is updated and torn down on failure
        trainer_fn: one of (fit, validate, test, predict)
        *args: positional arguments to be passed to the `trainer_fn`
        **kwargs: keyword arguments to be passed to `trainer_fn`
    """
    try:
        if trainer.strategy.launcher is not None:
            return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
        else:
            return trainer_fn(*args, **kwargs)
    # TODO: Unify both exceptions below, where `KeyboardInterrupt` doesn't re-raise
    except KeyboardInterrupt as exception:
        rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
        # the user could press Ctrl+C many times, so only shut down once
        if not trainer.interrupted:
            trainer.state.status = TrainerStatus.INTERRUPTED
            trainer._call_callback_hooks("on_exception", exception)
            for logger in trainer.loggers:
                logger.finalize("failed")
    except BaseException as exception:
        trainer.state.status = TrainerStatus.INTERRUPTED
        if distributed_available() and trainer.world_size > 1:
            # try syncing the remaining processes; kill them otherwise
            trainer.strategy.reconciliate_processes(traceback.format_exc())
        trainer._call_callback_hooks("on_exception", exception)
        for logger in trainer.loggers:
            logger.finalize("failed")
        trainer._teardown()
        # teardown might access the stage, so reset it afterwards
        trainer.state.stage = None
        raise
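
Because `call_and_handle_interrupt` is now a free function taking the trainer explicitly, its error funnel can be exercised without constructing a real `Trainer`. Below is a hedged test sketch (not part of this commit) using `unittest.mock.MagicMock`; the import path `pytorch_lightning.trainer.teardown` is inferred from the `from pytorch_lightning.trainer import teardown` line added in the second file, and the test names are illustrative.

from unittest.mock import MagicMock

import pytest

from pytorch_lightning.trainer.teardown import call_and_handle_interrupt


def test_keyboard_interrupt_is_not_reraised():
    # graceful-shutdown branch: callbacks are notified, nothing propagates
    trainer = MagicMock()
    trainer.strategy.launcher = None  # take the direct `trainer_fn(...)` path
    trainer.interrupted = False

    def trainer_fn():
        raise KeyboardInterrupt

    call_and_handle_interrupt(trainer, trainer_fn)  # returns instead of raising
    trainer._call_callback_hooks.assert_called_once()


def test_other_exceptions_reraise_after_teardown():
    # failure branch: teardown runs, then the exception propagates
    trainer = MagicMock()
    trainer.strategy.launcher = None
    trainer.world_size = 1  # skip the distributed reconciliation branch

    def trainer_fn():
        raise RuntimeError("boom")

    with pytest.raises(RuntimeError, match="boom"):
        call_and_handle_interrupt(trainer, trainer_fn)
    trainer._teardown.assert_called_once()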

src/pytorch_lightning/trainer/trainer.py

@@ -16,14 +16,13 @@ import inspect
 import logging
 import math
 import os
-import traceback
 import warnings
 from argparse import ArgumentParser, Namespace
 from contextlib import contextmanager
 from copy import deepcopy
 from datetime import timedelta
 from pathlib import Path
-from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Type, Union
+from typing import Any, Dict, Generator, Iterable, List, Optional, Type, Union
 from weakref import proxy

 import torch
@@ -38,7 +37,6 @@ from torch.utils.data import DataLoader
 import pytorch_lightning as pl
 from lightning_lite.utilities.cloud_io import get_filesystem
 from lightning_lite.utilities.data import _auto_add_worker_init_fn
-from lightning_lite.utilities.distributed import distributed_available
 from lightning_lite.utilities.types import _PATH
 from lightning_lite.utilities.warnings import PossibleUserWarning
 from pytorch_lightning.accelerators import (
@@ -75,6 +73,7 @@ from pytorch_lightning.profilers import (
     XLAProfiler,
 )
 from pytorch_lightning.strategies import ParallelStrategy, Strategy
+from pytorch_lightning.trainer import teardown
 from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations
 from pytorch_lightning.trainer.connectors.accelerator_connector import _LITERAL_WARN, AcceleratorConnector
 from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
@@ -620,43 +619,6 @@ class Trainer(
         self._last_train_dl_reload_epoch = None
         self._last_val_dl_reload_epoch: Optional[int] = None

-    def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: Any) -> Any:
-        r"""
-        Error handling, intended to be used only for main trainer function entry points (fit, validate, test, predict)
-        as all errors should funnel through them
-
-        Args:
-            trainer_fn: one of (fit, validate, test, predict)
-            *args: positional arguments to be passed to the `trainer_fn`
-            **kwargs: keyword arguments to be passed to `trainer_fn`
-        """
-        try:
-            if self.strategy.launcher is not None:
-                return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
-            else:
-                return trainer_fn(*args, **kwargs)
-        # TODO: Unify both exceptions below, where `KeyboardError` doesn't re-raise
-        except KeyboardInterrupt as exception:
-            rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
-            # user could press Ctrl+c many times... only shutdown once
-            if not self.interrupted:
-                self.state.status = TrainerStatus.INTERRUPTED
-                self._call_callback_hooks("on_exception", exception)
-                for logger in self.loggers:
-                    logger.finalize("failed")
-        except BaseException as exception:
-            self.state.status = TrainerStatus.INTERRUPTED
-            if distributed_available() and self.world_size > 1:
-                # try syncing remaining processes, kill otherwise
-                self.strategy.reconciliate_processes(traceback.format_exc())
-            self._call_callback_hooks("on_exception", exception)
-            for logger in self.loggers:
-                logger.finalize("failed")
-            self._teardown()
-            # teardown might access the stage so we reset it after
-            self.state.stage = None
-            raise
-
     def fit(
         self,
         model: "pl.LightningModule",
@@ -686,8 +648,8 @@
         if not isinstance(model, pl.LightningModule):
             raise TypeError(f"`Trainer.fit()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
         self.strategy._lightning_module = model
-        self._call_and_handle_interrupt(
-            self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
+        teardown.call_and_handle_interrupt(
+            self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
         )

     def _fit_impl(
@@ -766,7 +728,9 @@
         if model is not None and not isinstance(model, pl.LightningModule):
             raise TypeError(f"`Trainer.validate()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
         self.strategy._lightning_module = model or self.lightning_module
-        return self._call_and_handle_interrupt(self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule)
+        return teardown.call_and_handle_interrupt(
+            self, self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule
+        )

     def _validate_impl(
         self,
@@ -856,7 +820,9 @@
         if model is not None and not isinstance(model, pl.LightningModule):
             raise TypeError(f"`Trainer.test()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
         self.strategy._lightning_module = model or self.lightning_module
-        return self._call_and_handle_interrupt(self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule)
+        return teardown.call_and_handle_interrupt(
+            self, self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule
+        )

     def _test_impl(
         self,
@@ -945,8 +911,8 @@
         if model is not None and not isinstance(model, pl.LightningModule):
             raise TypeError(f"`Trainer.predict()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
         self.strategy._lightning_module = model or self.lightning_module
-        return self._call_and_handle_interrupt(
-            self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path
+        return teardown.call_and_handle_interrupt(
+            self, self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path
         )

     def _predict_impl(
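
From the user's perspective nothing changes: `fit`, `validate`, `test`, and `predict` still funnel every error through the handler. Here is a hedged end-to-end sketch (not part of this commit) of the graceful Ctrl+C path; it assumes the `BoringModel` demo shipped with pytorch_lightning of this era, and the callback class is written purely for this illustration.

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.demos.boring_classes import BoringModel


class InterruptOnTrainStart(Callback):
    def on_train_start(self, trainer, pl_module):
        raise KeyboardInterrupt  # simulate the user pressing Ctrl+C


trainer = pl.Trainer(
    max_epochs=1,
    logger=False,
    enable_checkpointing=False,
    enable_progress_bar=False,
    callbacks=[InterruptOnTrainStart()],
)
trainer.fit(BoringModel())  # returns gracefully instead of raising
assert trainer.interrupted  # status was set to TrainerStatus.INTERRUPTED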