diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py index 206951e608..af4a6be4c2 100644 --- a/pytorch_lightning/accelerators/accelerator_connector.py +++ b/pytorch_lightning/accelerators/accelerator_connector.py @@ -82,6 +82,8 @@ class AcceleratorConnector: # NVIDIA setup self.trainer.set_nvidia_flags(self.trainer.is_slurm_managing_tasks, self.trainer.data_parallel_device_ids) + self.trainer.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE') + def select_accelerator(self): # SLURM ddp use_slurm_ddp = self.trainer.use_ddp and self.trainer.is_slurm_managing_tasks diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 898fbc92cb..f7c53c1cbe 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -345,10 +345,3 @@ class TrainerDataLoadingMixin(ABC): hvd.join() return dataloader - - def determine_data_use_amount(self, overfit_batches: float) -> None: - """Use less data for debugging purposes""" - if overfit_batches > 0: - self.limit_train_batches = overfit_batches - self.limit_val_batches = overfit_batches - self.limit_test_batches = overfit_batches diff --git a/pytorch_lightning/trainer/debugging_connector.py b/pytorch_lightning/trainer/debugging_connector.py new file mode 100644 index 0000000000..49b32b0903 --- /dev/null +++ b/pytorch_lightning/trainer/debugging_connector.py @@ -0,0 +1,109 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from typing import Union +from pytorch_lightning.utilities import rank_zero_warn, rank_zero_info + + +class DebuggingConnector: + + def __init__(self, trainer): + self.trainer = trainer + + def on_init_start( + self, + overfit_pct, + val_percent_check, + test_percent_check, + train_percent_check, + limit_train_batches, + limit_val_batches, + limit_test_batches, + val_check_interval, + overfit_batches, + fast_dev_run + ): + self.trainer.fast_dev_run = fast_dev_run + if self.trainer.fast_dev_run: + limit_train_batches = 1 + limit_val_batches = 1 + limit_test_batches = 1 + self.trainer.num_sanity_val_steps = 0 + self.trainer.max_epochs = 1 + rank_zero_info( + 'Running in fast_dev_run mode: will run a full train,' ' val and test loop using a single batch' + ) + + # how much of the data to use + # TODO: remove in 0.10.0 + if overfit_pct is not None: + rank_zero_warn( + "Argument `overfit_pct` is now set by `overfit_batches` since v0.8.0" + " and this argument will be removed in v0.10.0", + DeprecationWarning, + ) + overfit_batches = overfit_pct + + # TODO: remove in 0.10.0 + if val_percent_check is not None: + rank_zero_warn( + "Argument `val_percent_check` is now set by `limit_val_batches` since v0.8.0" + " and this argument will be removed in v0.10.0", + DeprecationWarning, + ) + limit_val_batches = val_percent_check + + # TODO: remove in 0.10.0 + if test_percent_check is not None: + rank_zero_warn( + "Argument `test_percent_check` is now set by `limit_test_batches` since v0.8.0" + " and this argument will be removed in v0.10.0", + DeprecationWarning, + ) + limit_test_batches = test_percent_check + + # TODO: remove in 0.10.0 + if train_percent_check is not None: + rank_zero_warn( + "Argument `train_percent_check` is now set by `limit_train_batches` since v0.8.0" + " and this argument will be removed in v0.10.0", + DeprecationWarning, + ) + limit_train_batches = train_percent_check + + self.trainer.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches') + self.trainer.limit_val_batches = _determine_batch_limits(limit_val_batches, 'limit_val_batches') + self.trainer.limit_test_batches = _determine_batch_limits(limit_test_batches, 'limit_test_batches') + self.trainer.val_check_interval = _determine_batch_limits(val_check_interval, 'val_check_interval') + self.trainer.overfit_batches = _determine_batch_limits(overfit_batches, 'overfit_batches') + self.determine_data_use_amount(self.trainer.overfit_batches) + + def determine_data_use_amount(self, overfit_batches: float) -> None: + """Use less data for debugging purposes""" + if overfit_batches > 0: + self.trainer.limit_train_batches = overfit_batches + self.trainer.limit_val_batches = overfit_batches + self.trainer.limit_test_batches = overfit_batches + + +def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]: + if 0 <= batches <= 1: + return batches + elif batches > 1 and batches % 1.0 == 0: + return int(batches) + else: + raise MisconfigurationException( + f'You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int.' + ) diff --git a/pytorch_lightning/trainer/initializer.py b/pytorch_lightning/trainer/precision_connector.py similarity index 84% rename from pytorch_lightning/trainer/initializer.py rename to pytorch_lightning/trainer/precision_connector.py index b2a39056e1..55fb945caf 100644 --- a/pytorch_lightning/trainer/initializer.py +++ b/pytorch_lightning/trainer/precision_connector.py @@ -11,16 +11,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from pytorch_lightning import _logger as log from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE, rank_zero_warn, AMPType -class Initializer: +class PrecisionConnector: def __init__(self, trainer): self.trainer = trainer + def on_trainer_init(self, precision, amp_level, amp_backend): + # AMP init + # These are the only lines needed after v0.8.0 + # we wrap the user's forward with autocast and give it back at the end of fit + self.trainer.autocast_original_forward = None + self.trainer.precision = precision + self.trainer.scaler = None + + self.trainer.amp_level = amp_level + self.init_amp(amp_backend) + def init_amp(self, amp_type: str): assert self.trainer.precision in (16, 32), 'only 32 or 16 bit precision supported' self.trainer.amp_backend = None diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ef24759ce8..c4b3307937 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -43,15 +43,16 @@ from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop from pytorch_lightning.trainer.training_loop import TrainLoop -from pytorch_lightning.trainer.data_connector import DataConnector from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector from pytorch_lightning.trainer.logger_connector import LoggerConnector from pytorch_lightning.trainer.lr_scheduler_connector import LRSchedulerConnector from pytorch_lightning.trainer.callback_connector import CallbackConnector from pytorch_lightning.trainer.model_connector import ModelConnector +from pytorch_lightning.trainer.debugging_connector import DebuggingConnector from pytorch_lightning import _logger as log from pytorch_lightning.tuner.tuning import Tuner -from pytorch_lightning.trainer.initializer import Initializer +from pytorch_lightning.trainer.precision_connector import PrecisionConnector +from pytorch_lightning.trainer.data_connector import DataConnector from pytorch_lightning.utilities.model_utils import is_overridden from pytorch_lightning.trainer import docstrings from pytorch_lightning.trainer.properties import TrainerProperties @@ -173,8 +174,10 @@ class Trainer( self.accelerator_connector = AcceleratorConnector(self) self.logger_connector = LoggerConnector(self) self.model_connector = ModelConnector(self) - self.initializer = Initializer(self) + self.precision_connector = PrecisionConnector(self) self.callback_connector = CallbackConnector(self) + self.debugging_connector = DebuggingConnector(self) + self.tuner = Tuner(self) self.accelerator_backend = None @@ -253,15 +256,8 @@ class Trainer( # ------------------- self.weights_summary = weights_summary - self.max_epochs = max_epochs - self.min_epochs = min_epochs - self.max_steps = max_steps - self.min_steps = min_steps - - if num_sanity_val_steps == -1: - self.num_sanity_val_steps = float('inf') - else: - self.num_sanity_val_steps = num_sanity_val_steps + # init train loop related flags + self.train_loop.on_init_start(max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps) self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch @@ -275,17 +271,6 @@ class Trainer( self.terminate_on_nan = terminate_on_nan self.shown_warnings = set() - self.fast_dev_run = fast_dev_run - if self.fast_dev_run: - limit_train_batches = 1 - limit_val_batches = 1 - limit_test_batches = 1 - self.num_sanity_val_steps = 0 - self.max_epochs = 1 - rank_zero_info( - 'Running in fast_dev_run mode: will run a full train,' ' val and test loop using a single batch' - ) - # configure profiler if profiler is True: profiler = SimpleProfiler() @@ -300,61 +285,22 @@ class Trainer( self.log_save_interval = log_save_interval self.row_log_interval = row_log_interval - # how much of the data to use - # TODO: remove in 0.10.0 - if overfit_pct is not None: - rank_zero_warn( - "Argument `overfit_pct` is now set by `overfit_batches` since v0.8.0" - " and this argument will be removed in v0.10.0", - DeprecationWarning, - ) - overfit_batches = overfit_pct + # init debugging flags + self.debugging_connector.on_init_start( + overfit_pct, + val_percent_check, + test_percent_check, + train_percent_check, + limit_train_batches, + limit_val_batches, + limit_test_batches, + val_check_interval, + overfit_batches, + fast_dev_run + ) - # TODO: remove in 0.10.0 - if val_percent_check is not None: - rank_zero_warn( - "Argument `val_percent_check` is now set by `limit_val_batches` since v0.8.0" - " and this argument will be removed in v0.10.0", - DeprecationWarning, - ) - limit_val_batches = val_percent_check - - # TODO: remove in 0.10.0 - if test_percent_check is not None: - rank_zero_warn( - "Argument `test_percent_check` is now set by `limit_test_batches` since v0.8.0" - " and this argument will be removed in v0.10.0", - DeprecationWarning, - ) - limit_test_batches = test_percent_check - - # TODO: remove in 0.10.0 - if train_percent_check is not None: - rank_zero_warn( - "Argument `train_percent_check` is now set by `limit_train_batches` since v0.8.0" - " and this argument will be removed in v0.10.0", - DeprecationWarning, - ) - limit_train_batches = train_percent_check - - self.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches') - self.limit_val_batches = _determine_batch_limits(limit_val_batches, 'limit_val_batches') - self.limit_test_batches = _determine_batch_limits(limit_test_batches, 'limit_test_batches') - self.val_check_interval = _determine_batch_limits(val_check_interval, 'val_check_interval') - self.overfit_batches = _determine_batch_limits(overfit_batches, 'overfit_batches') - self.determine_data_use_amount(self.overfit_batches) - - # AMP init - # These are the only lines needed after v0.8.0 - # we wrap the user's forward with autocast and give it back at the end of fit - self.autocast_original_forward = None - self.precision = precision - self.scaler = None - - self.amp_level = amp_level - self.initializer.init_amp(amp_backend) - - self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE') + # set precision + self.precision_connector.on_trainer_init(precision, amp_level, amp_backend) # Callback system self.on_init_end() @@ -862,18 +808,6 @@ class Trainer( return output - -def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]: - if 0 <= batches <= 1: - return batches - elif batches > 1 and batches % 1.0 == 0: - return int(batches) - else: - raise MisconfigurationException( - f'You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int.' - ) - - # add docstrings Trainer.__init__.__doc__ = docstrings.trainer.init Trainer.fit.__doc__ = docstrings.trainer.fit diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index ef0d2074b2..f62e91c11b 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -38,6 +38,17 @@ class TrainLoop: self._teardown_already_run = False self.running_loss = TensorRunningAccum(window_length=20) + def on_init_start(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps): + self.trainer.max_epochs = max_epochs + self.trainer.min_epochs = min_epochs + self.trainer.max_steps = max_steps + self.trainer.min_steps = min_steps + + if num_sanity_val_steps == -1: + self.trainer.num_sanity_val_steps = float('inf') + else: + self.trainer.num_sanity_val_steps = num_sanity_val_steps + @property def num_optimizers(self): num_optimizers = len(self.get_optimizers_iterable())