Make Trainer readable and debuggable (2/n) (#14862)
* clean trainer 2/n
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent c77d4a8394
commit fbfcc3d871
@@ -0,0 +1,57 @@
+# Copyright Lightning AI.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import traceback
+from typing import Any, Callable
+
+from lightning_lite.utilities.distributed import distributed_available
+from pytorch_lightning.trainer.states import TrainerStatus
+from pytorch_lightning.utilities.rank_zero import rank_zero_warn
+
+
+def call_and_handle_interrupt(trainer: Any, trainer_fn: Callable, *args: Any, **kwargs: Any) -> Any:
+    r"""
+    Error handling, intended to be used only for main trainer function entry points (fit, validate, test, predict)
+    as all errors should funnel through them
+
+    Args:
+        trainer_fn: one of (fit, validate, test, predict)
+        *args: positional arguments to be passed to the `trainer_fn`
+        **kwargs: keyword arguments to be passed to `trainer_fn`
+    """
+    try:
+        if trainer.strategy.launcher is not None:
+            return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
+        else:
+            return trainer_fn(*args, **kwargs)
+    # TODO: Unify both exceptions below, where `KeyboardError` doesn't re-raise
+    except KeyboardInterrupt as exception:
+        rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
+        # user could press Ctrl+c many times... only shutdown once
+        if not trainer.interrupted:
+            trainer.state.status = TrainerStatus.INTERRUPTED
+            trainer._call_callback_hooks("on_exception", exception)
+            for logger in trainer.loggers:
+                logger.finalize("failed")
+    except BaseException as exception:
+        trainer.state.status = TrainerStatus.INTERRUPTED
+        if distributed_available() and trainer.world_size > 1:
+            # try syncing remaining processes, kill otherwise
+            trainer.strategy.reconciliate_processes(traceback.format_exc())
+        trainer._call_callback_hooks("on_exception", exception)
+        for logger in trainer.loggers:
+            logger.finalize("failed")
+        trainer._teardown()
+        # teardown might access the stage so we reset it after
+        trainer.state.stage = None
+        raise
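To make the control flow concrete, here is a hedged sketch of how a caller dispatches through this helper. The `Stub*` classes and `double` are hypothetical stand-ins, not part of the commit; only the `launch(trainer_fn, *args, trainer=..., **kwargs)` shape and the module path (inferred from `from pytorch_lightning.trainer import teardown` in the diff below) are taken from the change itself.

# Minimal sketch (not part of the commit) exercising the happy path of
# call_and_handle_interrupt with hypothetical stub objects.
from pytorch_lightning.trainer import teardown


class StubLauncher:
    def launch(self, trainer_fn, *args, trainer=None, **kwargs):
        # A real launcher would spawn worker processes; here we call through.
        return trainer_fn(*args, **kwargs)


class StubStrategy:
    launcher = StubLauncher()  # set to None to exercise the direct-call branch


class StubTrainer:
    strategy = StubStrategy()


def double(x: int) -> int:
    return x * 2


# Dispatches through the launcher because strategy.launcher is not None.
assert teardown.call_and_handle_interrupt(StubTrainer(), double, 21) == 42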
@@ -16,14 +16,13 @@ import inspect
 import logging
 import math
 import os
-import traceback
 import warnings
 from argparse import ArgumentParser, Namespace
 from contextlib import contextmanager
 from copy import deepcopy
 from datetime import timedelta
 from pathlib import Path
-from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Type, Union
+from typing import Any, Dict, Generator, Iterable, List, Optional, Type, Union
 from weakref import proxy

 import torch
@@ -38,7 +37,6 @@ from torch.utils.data import DataLoader
 import pytorch_lightning as pl
 from lightning_lite.utilities.cloud_io import get_filesystem
 from lightning_lite.utilities.data import _auto_add_worker_init_fn
-from lightning_lite.utilities.distributed import distributed_available
 from lightning_lite.utilities.types import _PATH
 from lightning_lite.utilities.warnings import PossibleUserWarning
 from pytorch_lightning.accelerators import (
@@ -75,6 +73,7 @@ from pytorch_lightning.profilers import (
     XLAProfiler,
 )
 from pytorch_lightning.strategies import ParallelStrategy, Strategy
+from pytorch_lightning.trainer import teardown
 from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations
 from pytorch_lightning.trainer.connectors.accelerator_connector import _LITERAL_WARN, AcceleratorConnector
 from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
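Note that the imports dropped above (`traceback`, `Callable`, `distributed_available`) are exactly the ones used only by the relocated interrupt handler, while the new `teardown` import brings in its replacement.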
@@ -620,43 +619,6 @@ class Trainer(
         self._last_train_dl_reload_epoch = None
         self._last_val_dl_reload_epoch: Optional[int] = None

-    def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: Any) -> Any:
-        r"""
-        Error handling, intended to be used only for main trainer function entry points (fit, validate, test, predict)
-        as all errors should funnel through them
-
-        Args:
-            trainer_fn: one of (fit, validate, test, predict)
-            *args: positional arguments to be passed to the `trainer_fn`
-            **kwargs: keyword arguments to be passed to `trainer_fn`
-        """
-        try:
-            if self.strategy.launcher is not None:
-                return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
-            else:
-                return trainer_fn(*args, **kwargs)
-        # TODO: Unify both exceptions below, where `KeyboardError` doesn't re-raise
-        except KeyboardInterrupt as exception:
-            rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
-            # user could press Ctrl+c many times... only shutdown once
-            if not self.interrupted:
-                self.state.status = TrainerStatus.INTERRUPTED
-                self._call_callback_hooks("on_exception", exception)
-                for logger in self.loggers:
-                    logger.finalize("failed")
-        except BaseException as exception:
-            self.state.status = TrainerStatus.INTERRUPTED
-            if distributed_available() and self.world_size > 1:
-                # try syncing remaining processes, kill otherwise
-                self.strategy.reconciliate_processes(traceback.format_exc())
-            self._call_callback_hooks("on_exception", exception)
-            for logger in self.loggers:
-                logger.finalize("failed")
-            self._teardown()
-            # teardown might access the stage so we reset it after
-            self.state.stage = None
-            raise
-
     def fit(
         self,
         model: "pl.LightningModule",
@@ -686,8 +648,8 @@ class Trainer(
         if not isinstance(model, pl.LightningModule):
             raise TypeError(f"`Trainer.fit()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
         self.strategy._lightning_module = model
-        self._call_and_handle_interrupt(
-            self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
+        teardown.call_and_handle_interrupt(
+            self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
         )

     def _fit_impl(
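Since both the old method and the new free function invoke `trainer._call_callback_hooks("on_exception", exception)`, user callbacks observe failures the same way after the refactor. A hedged sketch using the standard `Callback.on_exception` hook; the class name and print body are illustrative, not part of the commit:

import pytorch_lightning as pl


class ShutdownNotifier(pl.Callback):
    # Fired from call_and_handle_interrupt via
    # trainer._call_callback_hooks("on_exception", exception)
    # on Ctrl+C or any other error during fit/validate/test/predict.
    def on_exception(self, trainer, pl_module, exception):
        print(f"Run ended with {type(exception).__name__}; loggers are finalized as 'failed'.")

Registering it via `Trainer(callbacks=[ShutdownNotifier()])` is enough; nothing changes on the user side with this refactor.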
@@ -766,7 +728,9 @@ class Trainer(
         if model is not None and not isinstance(model, pl.LightningModule):
             raise TypeError(f"`Trainer.validate()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
         self.strategy._lightning_module = model or self.lightning_module
-        return self._call_and_handle_interrupt(self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule)
+        return teardown.call_and_handle_interrupt(
+            self, self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule
+        )

     def _validate_impl(
         self,
@@ -856,7 +820,9 @@ class Trainer(
         if model is not None and not isinstance(model, pl.LightningModule):
             raise TypeError(f"`Trainer.test()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
         self.strategy._lightning_module = model or self.lightning_module
-        return self._call_and_handle_interrupt(self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule)
+        return teardown.call_and_handle_interrupt(
+            self, self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule
+        )

     def _test_impl(
         self,
@@ -945,8 +911,8 @@ class Trainer(
         if model is not None and not isinstance(model, pl.LightningModule):
             raise TypeError(f"`Trainer.predict()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
         self.strategy._lightning_module = model or self.lightning_module
-        return self._call_and_handle_interrupt(
-            self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path
+        return teardown.call_and_handle_interrupt(
+            self, self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path
         )

     def _predict_impl(
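The pattern applied at all four call sites is a method-to-free-function extraction: the error funnel takes the trainer explicitly instead of living as a private method, so it can be exercised without constructing a full `Trainer`. A toy sketch of the same move, with hypothetical names:

# Toy analogue of this commit's refactor: the guarded-call funnel moves from
# a private method to a module-level function that receives the object.
class Widget:
    def run(self) -> str:
        # was: self._run_guarded(self._run_impl)
        return run_guarded(self, self._run_impl)

    def _run_impl(self) -> str:
        return "ok"


def run_guarded(obj, fn, *args, **kwargs):
    # Central error funnel, analogous to teardown.call_and_handle_interrupt.
    try:
        return fn(*args, **kwargs)
    except KeyboardInterrupt:
        # graceful-shutdown bookkeeping would go here (cf. the new module)
        raise


assert Widget().run() == "ok"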