From 047d7088f4366e68516d017f6e2b81f61bdef13a Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Fri, 23 Sep 2022 11:10:42 -0400
Subject: [PATCH] Make Trainer readable and debuggable (3/n) (#14871)

* clean trainer 3/n

* clean trainer 3/n

* clean trainer 3/n

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* clean trainer 3/n

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 src/pytorch_lightning/trainer/setup.py   | 204 ++++++++++++++++++++++
 src/pytorch_lightning/trainer/trainer.py | 208 ++---------------------
 2 files changed, 220 insertions(+), 192 deletions(-)
 create mode 100644 src/pytorch_lightning/trainer/setup.py

diff --git a/src/pytorch_lightning/trainer/setup.py b/src/pytorch_lightning/trainer/setup.py
new file mode 100644
index 0000000000..3fc7ba6cc7
--- /dev/null
+++ b/src/pytorch_lightning/trainer/setup.py
@@ -0,0 +1,204 @@
+# Copyright Lightning AI.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Houses the methods used to set up the Trainer."""
+
+from typing import Any, Optional, Union
+
+from lightning_lite.utilities.warnings import PossibleUserWarning
+from pytorch_lightning.accelerators import (
+    CUDAAccelerator,
+    HPUAccelerator,
+    IPUAccelerator,
+    MPSAccelerator,
+    TPUAccelerator,
+)
+from pytorch_lightning.loggers.logger import DummyLogger
+from pytorch_lightning.profilers import (
+    AdvancedProfiler,
+    PassThroughProfiler,
+    Profiler,
+    PyTorchProfiler,
+    SimpleProfiler,
+    XLAProfiler,
+)
+from pytorch_lightning.utilities import _HPU_AVAILABLE, _IPU_AVAILABLE, _TPU_AVAILABLE
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
+
+
+def init_debugging_flags(
+    trainer: Any,
+    limit_train_batches: Optional[Union[int, float]],
+    limit_val_batches: Optional[Union[int, float]],
+    limit_test_batches: Optional[Union[int, float]],
+    limit_predict_batches: Optional[Union[int, float]],
+    fast_dev_run: Union[int, bool],
+    overfit_batches: Union[int, float],
+    val_check_interval: Optional[Union[int, float]],
+    num_sanity_val_steps: int,
+) -> None:
+    # init debugging flags
+    if isinstance(fast_dev_run, int) and (fast_dev_run < 0):
+        raise MisconfigurationException(
+            f"fast_dev_run={fast_dev_run!r} is not a valid configuration. It should be >= 0."
+        )
+    trainer.fast_dev_run = fast_dev_run
+
+    # set fast_dev_run=True when it is 1, used while logging
+    if fast_dev_run == 1:
+        trainer.fast_dev_run = True
+
+    trainer.overfit_batches = _determine_batch_limits(overfit_batches, "overfit_batches")
+    overfit_batches_enabled = overfit_batches > 0
+
+    if fast_dev_run:
+        num_batches = int(fast_dev_run)
+        if not overfit_batches_enabled:
+            trainer.limit_train_batches = num_batches
+            trainer.limit_val_batches = num_batches
+
+        trainer.limit_test_batches = num_batches
+        trainer.limit_predict_batches = num_batches
+        trainer.fit_loop.max_steps = num_batches
+        trainer.num_sanity_val_steps = 0
+        trainer.fit_loop.max_epochs = 1
+        trainer.val_check_interval = 1.0
+        trainer.check_val_every_n_epoch = 1
+        trainer.loggers = [DummyLogger()] if trainer.loggers else []
+        rank_zero_info(
+            f"Running in `fast_dev_run` mode: will run the requested loop using {num_batches} batch(es). "
+            "Logging and checkpointing are suppressed."
+        )
+    else:
+        if not overfit_batches_enabled:
+            trainer.limit_train_batches = _determine_batch_limits(limit_train_batches, "limit_train_batches")
+            trainer.limit_val_batches = _determine_batch_limits(limit_val_batches, "limit_val_batches")
+        trainer.limit_test_batches = _determine_batch_limits(limit_test_batches, "limit_test_batches")
+        trainer.limit_predict_batches = _determine_batch_limits(limit_predict_batches, "limit_predict_batches")
+        trainer.num_sanity_val_steps = float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps
+        trainer.val_check_interval = _determine_batch_limits(val_check_interval, "val_check_interval")
+
+    if overfit_batches_enabled:
+        trainer.limit_train_batches = overfit_batches
+        trainer.limit_val_batches = overfit_batches
+
+
+def _determine_batch_limits(batches: Optional[Union[int, float]], name: str) -> Union[int, float]:
+    if batches is None:
+        # batches is optional so that we know whether the user passed a value. the following info messages are
+        # shown only to users that set a value explicitly
+        return 1.0
+
+    # differentiating based on the type can be error-prone for users. show a message describing the chosen behaviour
+    if isinstance(batches, int) and batches == 1:
+        if name == "limit_train_batches":
+            message = "1 batch per epoch will be used."
+        elif name == "val_check_interval":
+            message = "validation will run after every batch."
+        else:
+            message = "1 batch will be used."
+        rank_zero_info(f"`Trainer({name}=1)` was configured so {message}")
+    elif isinstance(batches, float) and batches == 1.0:
+        if name == "limit_train_batches":
+            message = "100% of the batches per epoch will be used."
+        elif name == "val_check_interval":
+            message = "validation will run at the end of the training epoch."
+        else:
+            message = "100% of the batches will be used."
+        rank_zero_info(f"`Trainer({name}=1.0)` was configured so {message}")
+
+    if 0 <= batches <= 1:
+        return batches
+    if batches > 1 and batches % 1.0 == 0:
+        return int(batches)
+    raise MisconfigurationException(
+        f"You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int."
+    )
+
+
+def init_profiler(trainer: Any, profiler: Optional[Union[Profiler, str]]) -> None:
+    if isinstance(profiler, str):
+        PROFILERS = {
+            "simple": SimpleProfiler,
+            "advanced": AdvancedProfiler,
+            "pytorch": PyTorchProfiler,
+            "xla": XLAProfiler,
+        }
+        profiler = profiler.lower()
+        if profiler not in PROFILERS:
+            raise MisconfigurationException(
+                "When passing string value for the `profiler` parameter of `Trainer`,"
+                f" it can only be one of {list(PROFILERS.keys())}"
+            )
+        profiler_class = PROFILERS[profiler]
+        profiler = profiler_class()
+    trainer.profiler = profiler or PassThroughProfiler()
+
+
+def log_device_info(trainer: Any) -> None:
+
+    if CUDAAccelerator.is_available():
+        gpu_available = True
+        gpu_type = " (cuda)"
+    elif MPSAccelerator.is_available():
+        gpu_available = True
+        gpu_type = " (mps)"
+    else:
+        gpu_available = False
+        gpu_type = ""
+
+    gpu_used = isinstance(trainer.accelerator, (CUDAAccelerator, MPSAccelerator))
+    rank_zero_info(f"GPU available: {gpu_available}{gpu_type}, used: {gpu_used}")
+
+    num_tpu_cores = trainer.num_devices if isinstance(trainer.accelerator, TPUAccelerator) else 0
+    rank_zero_info(f"TPU available: {_TPU_AVAILABLE}, using: {num_tpu_cores} TPU cores")
+
+    num_ipus = trainer.num_devices if isinstance(trainer.accelerator, IPUAccelerator) else 0
+    rank_zero_info(f"IPU available: {_IPU_AVAILABLE}, using: {num_ipus} IPUs")
+
+    num_hpus = trainer.num_devices if isinstance(trainer.accelerator, HPUAccelerator) else 0
+    rank_zero_info(f"HPU available: {_HPU_AVAILABLE}, using: {num_hpus} HPUs")
+
+    # TODO: Integrate MPS Accelerator here, once gpu maps to both
+    if CUDAAccelerator.is_available() and not isinstance(trainer.accelerator, CUDAAccelerator):
+        rank_zero_warn(
+            "GPU available but not used. Set `accelerator` and `devices` using"
+            f" `Trainer(accelerator='gpu', devices={CUDAAccelerator.auto_device_count()})`.",
+            category=PossibleUserWarning,
+        )
+
+    if _TPU_AVAILABLE and not isinstance(trainer.accelerator, TPUAccelerator):
+        rank_zero_warn(
+            "TPU available but not used. Set `accelerator` and `devices` using"
+            f" `Trainer(accelerator='tpu', devices={TPUAccelerator.auto_device_count()})`."
+        )
+
+    if _IPU_AVAILABLE and not isinstance(trainer.accelerator, IPUAccelerator):
+        rank_zero_warn(
+            "IPU available but not used. Set `accelerator` and `devices` using"
+            f" `Trainer(accelerator='ipu', devices={IPUAccelerator.auto_device_count()})`."
+        )
+
+    if _HPU_AVAILABLE and not isinstance(trainer.accelerator, HPUAccelerator):
+        rank_zero_warn(
+            "HPU available but not used. Set `accelerator` and `devices` using"
+            f" `Trainer(accelerator='hpu', devices={HPUAccelerator.auto_device_count()})`."
+        )
+
+    if MPSAccelerator.is_available() and not isinstance(trainer.accelerator, MPSAccelerator):
+        rank_zero_warn(
+            "MPS available but not used. Set `accelerator` and `devices` using"
+            f" `Trainer(accelerator='mps', devices={MPSAccelerator.auto_device_count()})`."
+        )
diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py
index c21fb226ec..c438b2f767 100644
--- a/src/pytorch_lightning/trainer/trainer.py
+++ b/src/pytorch_lightning/trainer/trainer.py
@@ -11,6 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+# THIS FILE MUST READ EASILY, FOR UNDERSTANDING AND DEBUGGING PURPOSES.
+# DO NOT OBSCURE THE TRAINING LOOP
+# THIS IS A HARD REQUIREMENT TO CONTRIBUTING TO LIGHTNING
+# WE FAVOR READABILITY OVER ENGINEERING-CONSTRUCTS BY DESIGN
+# DO NOT REMOVE THIS NOTICE
+# - WILLIAM FALCON
+
 """Trainer to automate the training."""
 import inspect
 import logging
@@ -39,20 +47,12 @@ from lightning_lite.utilities.cloud_io import get_filesystem
 from lightning_lite.utilities.data import _auto_add_worker_init_fn
 from lightning_lite.utilities.types import _PATH
 from lightning_lite.utilities.warnings import PossibleUserWarning
-from pytorch_lightning.accelerators import (
-    Accelerator,
-    CUDAAccelerator,
-    HPUAccelerator,
-    IPUAccelerator,
-    MPSAccelerator,
-    TPUAccelerator,
-)
+from pytorch_lightning.accelerators import Accelerator, HPUAccelerator, TPUAccelerator
 from pytorch_lightning.callbacks import Callback, Checkpoint, EarlyStopping, ProgressBarBase
 from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.loggers import Logger
-from pytorch_lightning.loggers.logger import DummyLogger
 from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
 from pytorch_lightning.loops import PredictionLoop, TrainingEpochLoop
 from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
@@ -64,16 +64,9 @@ from pytorch_lightning.plugins import (
     PLUGIN_INPUT,
     PrecisionPlugin,
 )
-from pytorch_lightning.profilers import (
-    AdvancedProfiler,
-    PassThroughProfiler,
-    Profiler,
-    PyTorchProfiler,
-    SimpleProfiler,
-    XLAProfiler,
-)
+from pytorch_lightning.profilers import Profiler
 from pytorch_lightning.strategies import ParallelStrategy, Strategy
-from pytorch_lightning.trainer import teardown
+from pytorch_lightning.trainer import setup, teardown
 from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations
 from pytorch_lightning.trainer.connectors.accelerator_connector import _LITERAL_WARN, AcceleratorConnector
 from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
@@ -87,14 +80,7 @@ from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
 from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
 from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.tuner.tuning import _TunerResult, Tuner
-from pytorch_lightning.utilities import (
-    _HPU_AVAILABLE,
-    _IPU_AVAILABLE,
-    _TPU_AVAILABLE,
-    AMPType,
-    GradClipAlgorithmType,
-    parsing,
-)
+from pytorch_lightning.utilities import AMPType, GradClipAlgorithmType, parsing
 from pytorch_lightning.utilities.argparse import (
     _defaults_from_env_vars,
     add_argparse_args,
@@ -523,7 +509,7 @@ class Trainer(
         self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size)
 
         # configure profiler
-        self.__init_profiler(profiler)
+        setup.init_profiler(self, profiler)
 
         # init logger flags
         self._loggers: List[Logger]
@@ -531,7 +517,8 @@ class Trainer(
 
         # init debugging flags
         self.val_check_interval: Union[int, float]
-        self._init_debugging_flags(
+        setup.init_debugging_flags(
+            self,
             limit_train_batches,
             limit_val_batches,
             limit_test_batches,
@@ -545,64 +532,8 @@ class Trainer(
         # Callback system
         self._call_callback_hooks("on_init_end")
 
-    def _init_debugging_flags(
-        self,
-        limit_train_batches: Optional[Union[int, float]],
-        limit_val_batches: Optional[Union[int, float]],
-        limit_test_batches: Optional[Union[int, float]],
-        limit_predict_batches: Optional[Union[int, float]],
-        fast_dev_run: Union[int, bool],
-        overfit_batches: Union[int, float],
-        val_check_interval: Optional[Union[int, float]],
-        num_sanity_val_steps: int,
-    ):
-        # init debugging flags
-        if isinstance(fast_dev_run, int) and (fast_dev_run < 0):
-            raise MisconfigurationException(
-                f"fast_dev_run={fast_dev_run!r} is not a valid configuration. It should be >= 0."
-            )
-        self.fast_dev_run = fast_dev_run
-
-        # set fast_dev_run=True when it is 1, used while logging
-        if fast_dev_run == 1:
-            self.fast_dev_run = True
-
-        self.overfit_batches = _determine_batch_limits(overfit_batches, "overfit_batches")
-        overfit_batches_enabled = overfit_batches > 0
-
-        if fast_dev_run:
-            num_batches = int(fast_dev_run)
-            if not overfit_batches_enabled:
-                self.limit_train_batches = num_batches
-                self.limit_val_batches = num_batches
-
-            self.limit_test_batches = num_batches
-            self.limit_predict_batches = num_batches
-            self.fit_loop.max_steps = num_batches
-            self.num_sanity_val_steps = 0
-            self.fit_loop.max_epochs = 1
-            self.val_check_interval = 1.0
-            self.check_val_every_n_epoch = 1
-            self.loggers = [DummyLogger()] if self.loggers else []
-            rank_zero_info(
-                f"Running in `fast_dev_run` mode: will run the requested loop using {num_batches} batch(es). "
-                "Logging and checkpointing is suppressed."
-            )
-        else:
-            if not overfit_batches_enabled:
-                self.limit_train_batches = _determine_batch_limits(limit_train_batches, "limit_train_batches")
-                self.limit_val_batches = _determine_batch_limits(limit_val_batches, "limit_val_batches")
-            self.limit_test_batches = _determine_batch_limits(limit_test_batches, "limit_test_batches")
-            self.limit_predict_batches = _determine_batch_limits(limit_predict_batches, "limit_predict_batches")
-            self.num_sanity_val_steps = float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps
-            self.val_check_interval = _determine_batch_limits(val_check_interval, "val_check_interval")
-
-        if overfit_batches_enabled:
-            self.limit_train_batches = overfit_batches
-            self.limit_val_batches = overfit_batches
-
     def _setup_on_init(self) -> None:
-        self._log_device_info()
+        setup.log_device_info(self)
 
         self.should_stop = False
         self.state = TrainerState()
@@ -1591,85 +1522,11 @@ class Trainer(
     def _log_api_event(event: str) -> None:
         torch._C._log_api_usage_once("lightning.trainer." + event)
 
-    def __init_profiler(self, profiler: Optional[Union[Profiler, str]]) -> None:
-        if isinstance(profiler, str):
-            PROFILERS = {
-                "simple": SimpleProfiler,
-                "advanced": AdvancedProfiler,
-                "pytorch": PyTorchProfiler,
-                "xla": XLAProfiler,
-            }
-            profiler = profiler.lower()
-            if profiler not in PROFILERS:
-                raise MisconfigurationException(
-                    "When passing string value for the `profiler` parameter of `Trainer`,"
-                    f" it can only be one of {list(PROFILERS.keys())}"
-                )
-            profiler_class = PROFILERS[profiler]
-            profiler = profiler_class()
-        self.profiler: Profiler = profiler or PassThroughProfiler()
-
     def __setup_profiler(self) -> None:
         local_rank = self.local_rank if self.world_size > 1 else None
         self.profiler._lightning_module = proxy(self.lightning_module)
         self.profiler.setup(stage=self.state.fn._setup_fn, local_rank=local_rank, log_dir=self.log_dir)
 
-    def _log_device_info(self) -> None:
-
-        if CUDAAccelerator.is_available():
-            gpu_available = True
-            gpu_type = " (cuda)"
-        elif MPSAccelerator.is_available():
-            gpu_available = True
-            gpu_type = " (mps)"
-        else:
-            gpu_available = False
-            gpu_type = ""
-
-        gpu_used = isinstance(self.accelerator, (CUDAAccelerator, MPSAccelerator))
-        rank_zero_info(f"GPU available: {gpu_available}{gpu_type}, used: {gpu_used}")
-
-        num_tpu_cores = self.num_devices if isinstance(self.accelerator, TPUAccelerator) else 0
-        rank_zero_info(f"TPU available: {_TPU_AVAILABLE}, using: {num_tpu_cores} TPU cores")
-
-        num_ipus = self.num_devices if isinstance(self.accelerator, IPUAccelerator) else 0
-        rank_zero_info(f"IPU available: {_IPU_AVAILABLE}, using: {num_ipus} IPUs")
-
-        num_hpus = self.num_devices if isinstance(self.accelerator, HPUAccelerator) else 0
-        rank_zero_info(f"HPU available: {_HPU_AVAILABLE}, using: {num_hpus} HPUs")
-
-        # TODO: Integrate MPS Accelerator here, once gpu maps to both
-        if CUDAAccelerator.is_available() and not isinstance(self.accelerator, CUDAAccelerator):
-            rank_zero_warn(
-                "GPU available but not used. Set `accelerator` and `devices` using"
-                f" `Trainer(accelerator='gpu', devices={CUDAAccelerator.auto_device_count()})`.",
-                category=PossibleUserWarning,
-            )
-
-        if _TPU_AVAILABLE and not isinstance(self.accelerator, TPUAccelerator):
-            rank_zero_warn(
-                "TPU available but not used. Set `accelerator` and `devices` using"
-                f" `Trainer(accelerator='tpu', devices={TPUAccelerator.auto_device_count()})`."
-            )
-
-        if _IPU_AVAILABLE and not isinstance(self.accelerator, IPUAccelerator):
-            rank_zero_warn(
-                "IPU available but not used. Set `accelerator` and `devices` using"
-                f" `Trainer(accelerator='ipu', devices={IPUAccelerator.auto_device_count()})`."
-            )
-
-        if _HPU_AVAILABLE and not isinstance(self.accelerator, HPUAccelerator):
-            rank_zero_warn(
-                "HPU available but not used. Set `accelerator` and `devices` using"
-                f" `Trainer(accelerator='hpu', devices={HPUAccelerator.auto_device_count()})`."
-            )
-
-        if MPSAccelerator.is_available() and not isinstance(self.accelerator, MPSAccelerator):
-            rank_zero_warn(
-                "MPS available but not used. Set `accelerator` and `devices` using"
-                f" `Trainer(accelerator='mps', devices={MPSAccelerator.auto_device_count()})`."
-            )
-
     """
     Data loading methods
     """
@@ -2542,36 +2399,3 @@ def _evaluation_context(accelerator: Accelerator) -> Generator:
         )
     with context_manager_class():
         yield
-
-
-def _determine_batch_limits(batches: Optional[Union[int, float]], name: str) -> Union[int, float]:
-    if batches is None:
-        # batches is optional to know if the user passed a value so that we can show the above info messages only to the
-        # users that set a value explicitly
-        return 1.0
-
-    # differentiating based on the type can be error-prone for users. show a message describing the chosen behaviour
-    if isinstance(batches, int) and batches == 1:
-        if name == "limit_train_batches":
-            message = "1 batch per epoch will be used."
-        elif name == "val_check_interval":
-            message = "validation will run after every batch."
-        else:
-            message = "1 batch will be used."
-        rank_zero_info(f"`Trainer({name}=1)` was configured so {message}")
-    elif isinstance(batches, float) and batches == 1.0:
-        if name == "limit_train_batches":
-            message = "100% of the batches per epoch will be used."
-        elif name == "val_check_interval":
-            message = "validation will run at the end of the training epoch."
-        else:
-            message = "100% of the batches will be used."
-        rank_zero_info(f"`Trainer({name}=1.0)` was configured so {message}.")
-
-    if 0 <= batches <= 1:
-        return batches
-    if batches > 1 and batches % 1.0 == 0:
-        return int(batches)
-    raise MisconfigurationException(
-        f"You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int."
-    )
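
Note (not part of the patch): a minimal usage sketch of the `fast_dev_run` behaviour that `init_debugging_flags` implements. A truthy value caps every loop at that many batches, skips sanity checking, and swaps the real loggers for a `DummyLogger`; the asserted attribute values follow directly from the function body above, assuming default values for the other Trainer flags.

from pytorch_lightning import Trainer

# fast_dev_run=True behaves like fast_dev_run=1: one batch per loop.
trainer = Trainer(fast_dev_run=True)
assert trainer.fast_dev_run is True
assert trainer.limit_train_batches == 1
assert trainer.num_sanity_val_steps == 0

# An int > 1 runs that many batches per loop instead.
trainer = Trainer(fast_dev_run=7)
assert trainer.limit_train_batches == 7

# Negative values are rejected up front:
# Trainer(fast_dev_run=-1)  # raises MisconfigurationException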
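Note (not part of the patch): a second sketch showing the int-versus-float convention that `_determine_batch_limits` enforces for the `limit_*_batches` and `val_check_interval` flags: floats in [0.0, 1.0] are fractions of the available batches, larger whole numbers are absolute batch counts, and everything else is rejected.

from pytorch_lightning import Trainer

trainer = Trainer(limit_train_batches=0.25)   # use 25% of the training batches per epoch
trainer = Trainer(limit_train_batches=100)    # use exactly 100 batches per epoch
trainer = Trainer(limit_train_batches=100.0)  # whole floats > 1 are coerced to int: also 100 batches

# 1 and 1.0 deliberately mean different things, hence the info messages:
trainer = Trainer(val_check_interval=1)       # validate after every training batch
trainer = Trainer(val_check_interval=1.0)     # validate once, at the end of the epoch

# Non-integral values above 1 fail the `batches % 1.0 == 0` check:
# Trainer(limit_train_batches=1.5)  # raises MisconfigurationException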
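Note (not part of the patch): a final sketch of the profiler selection in `init_profiler`: a string is lower-cased and mapped through the `PROFILERS` dict, a `Profiler` instance is used as-is, and `None` falls back to a `PassThroughProfiler`.

from pytorch_lightning import Trainer
from pytorch_lightning.profilers import PassThroughProfiler, SimpleProfiler

# These two constructions configure the same profiler:
trainer = Trainer(profiler="simple")
assert isinstance(trainer.profiler, SimpleProfiler)
trainer = Trainer(profiler=SimpleProfiler())

# The default (None) becomes the do-nothing profiler:
trainer = Trainer()
assert isinstance(trainer.profiler, PassThroughProfiler)

# Unknown strings are rejected:
# Trainer(profiler="tensorboard")  # raises MisconfigurationException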