From 6e6e29af49df7bf0e27e8197ccd3b9c3e3adaba6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 17 May 2021 09:57:15 +0200
Subject: [PATCH] remove trainer hidden state | sanity refactor [2 / n] (#7507)

---
 CHANGELOG.md                                |  1 +
 pytorch_lightning/core/lightning.py         |  2 +-
 .../trainer/connectors/data_connector.py    |  3 ++-
 .../logger_connector/logger_connector.py    |  1 -
 pytorch_lightning/trainer/data_loading.py   |  2 +-
 pytorch_lightning/trainer/trainer.py        |  6 ++---
 pytorch_lightning/trainer/training_loop.py  | 26 +++++++++----------
 .../loops/test_evaluation_loop_flow.py      |  8 ++++--
 .../loops/test_training_loop_flow_scalar.py |  8 ++++--
 9 files changed, 31 insertions(+), 26 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index fea7a41400..0173cc8d25 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -36,6 +36,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Refactored Loops
     * Moved attributes `global_step`, `current_epoch`, `max/min_steps`, `max/min_epochs`, `batch_idx`, and `total_batch_idx` to TrainLoop ([#7437](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
     * Refactored result handling in training loop ([#7506](https://github.com/PyTorchLightning/pytorch-lightning/pull/7506))
+    * Moved attributes `hiddens` and `split_idx` to TrainLoop ([#7507](https://github.com/PyTorchLightning/pytorch-lightning/pull/7507))


 - `DataModule`s now avoid duplicate `{setup,teardown,prepare_data}` calls for the same stage ([#7238](https://github.com/PyTorchLightning/pytorch-lightning/pull/7238))
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index b5306c6041..f1269e4ef1 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1635,7 +1635,7 @@ class LightningModule(
         module_tbptt_enabled = self.truncated_bptt_steps > 0
         trainer_tbptt_enabled = self.trainer.truncated_bptt_steps is not None and self.trainer.truncated_bptt_steps > 0
         if module_tbptt_enabled or trainer_tbptt_enabled:
-            tqdm_dict["split_idx"] = self.trainer.split_idx
+            tqdm_dict["split_idx"] = self.trainer.train_loop.split_idx

         if self.trainer.logger is not None and self.trainer.logger.version is not None:
             version = self.trainer.logger.version
diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py
index 9fb531f8eb..a867bf96a8 100644
--- a/pytorch_lightning/trainer/connectors/data_connector.py
+++ b/pytorch_lightning/trainer/connectors/data_connector.py
@@ -24,8 +24,9 @@ from pytorch_lightning.utilities.model_helpers import is_overridden

 class DataConnector(object):

-    def __init__(self, trainer):
+    def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"):
         self.trainer = trainer
+        self.multiple_trainloader_mode = multiple_trainloader_mode

     def on_trainer_init(
         self, check_val_every_n_epoch: int, reload_dataloaders_every_epoch: bool, prepare_data_per_node: bool
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index 1c82985576..8cf47b5848 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -126,7 +126,6 @@ class LoggerConnector:
         self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
         self.trainer.log_every_n_steps = log_every_n_steps
         self.trainer.move_metrics_to_cpu = move_metrics_to_cpu
-        self.trainer.split_idx = None

     @property
     def should_flush_logs(self):
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 29711b23d8..a20dca5084 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -259,7 +259,7 @@ class TrainerDataLoadingMixin(ABC):
         apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn)

         # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
-        self.train_dataloader = CombinedLoader(self.train_dataloader, self._multiple_trainloader_mode)
+        self.train_dataloader = CombinedLoader(self.train_dataloader, self.data_connector.multiple_trainloader_mode)

         self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index a9a431ddbb..6c0d062387 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -313,7 +313,7 @@ class Trainer(
         # init connectors
         self.dev_debugger = InternalDebugger(self)
         self.config_validator = ConfigValidator(self)
-        self.data_connector = DataConnector(self)
+        self.data_connector = DataConnector(self, multiple_trainloader_mode)
         self.optimizer_connector = OptimizerConnector(self)

         self.accelerator_connector = AcceleratorConnector(
@@ -329,9 +329,7 @@
         self.checkpoint_connector = CheckpointConnector(self)
         self.slurm_connector = SLURMConnector(self)
         self.tuner = Tuner(self)
-        self.train_loop = TrainLoop(
-            self, multiple_trainloader_mode, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps
-        )
+        self.train_loop = TrainLoop(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps)
         self.evaluation_loop = EvaluationLoop(self)
         self.predict_loop = PredictLoop(self)

diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index f7da8e929e..1184be7850 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -40,7 +40,6 @@ class TrainLoop:
     def __init__(
         self,
         trainer,
-        multiple_trainloader_mode: str,
         max_epochs: Optional[int],
         min_epochs: Optional[int],
         max_steps: Optional[int],
@@ -52,17 +51,21 @@
         self.warning_cache = WarningCache()
         self._teardown_already_run = False
         self.running_loss = TensorRunningAccum(window_length=20)
-        self._multiple_trainloader_mode = multiple_trainloader_mode
         self._skip_backward = False
-        self.trainer._multiple_trainloader_mode = multiple_trainloader_mode
         self._optimizer_freq_cumsum = None
+        self._hiddens = None

         self.global_step = 0
         self.current_epoch = 0
         self.trainer.should_stop = False

+        # the total batch index across all epochs
         self.total_batch_idx = 0
+        # the current batch index in the loop that runs over the dataloader(s)
         self.batch_idx = 0
+        # the current split index when the batch gets split into chunks in truncated backprop through time
+        self.split_idx = None
+
         self.trainer.num_training_batches = 0
         self.trainer.train_dataloader = None

@@ -337,7 +340,7 @@
             # map to results under the hood
             result.minimize = loss
-            self.trainer.hiddens = hiddens
+            self._hiddens = hiddens

             # track batch for manual reduction with result
             result.track_batch_size(len(split_batch))

@@ -479,7 +482,6 @@
         for batch_idx, (batch, is_last_batch) in train_dataloader:
             self.batch_idx = batch_idx
-
             self.trainer.is_last_batch = is_last_batch

             # ------------------------------------
             # TRAINING_STEP + TRAINING_STEP_END
@@ -656,7 +658,7 @@ class TrainLoop:
         grad_norm_dict = {}

         # bookkeeping
-        self.trainer.hiddens = None
+        self._hiddens = None

         optimizers = self.prepare_optimizers()

@@ -685,6 +687,7 @@ class TrainLoop:
         splits = self._tbptt_split_batch(batch)

         for split_idx, split_batch in enumerate(splits):
+            self.split_idx = split_idx

             # create an iterable for optimizers and loop over them
             for opt_idx, optimizer in optimizers:
@@ -703,9 +706,7 @@
                     # automatic_optimization=True: perform dpp sync only when performing optimizer_step
                     # automatic_optimization=False: don't block synchronization here
                     with self.block_ddp_sync_behaviour():
-                        result = self.training_step_and_backward(
-                            split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
-                        )
+                        self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self._hiddens)

                 # ------------------------------
                 # BACKWARD PASS
@@ -717,7 +718,7 @@
                         def train_step_and_backward_closure():
                             nonlocal result
                             result = self.training_step_and_backward(
-                                split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
+                                split_batch, batch_idx, opt_idx, optimizer, self._hiddens
                             )
                             return None if result is None else result.loss

@@ -725,7 +726,7 @@
                         self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)

                 else:
-                    result = self.training_step(split_batch, batch_idx, opt_idx, self.trainer.hiddens)
+                    result = self.training_step(split_batch, batch_idx, opt_idx, self._hiddens)

                     if not result:
                         # user decided to skip optimization
@@ -968,9 +969,6 @@ class TrainLoop:
         return optimizers

     def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer):
-        # set split_idx to trainer for tracking
-        self.trainer.split_idx = split_idx
-
         # make sure only the gradients of the current optimizer's parameters are calculated
         # in the training step to prevent dangling gradients in multiple-optimizer setup.
         if self.trainer.lightning_module.automatic_optimization and len(self.trainer.optimizers) > 1:
diff --git a/tests/trainer/loops/test_evaluation_loop_flow.py b/tests/trainer/loops/test_evaluation_loop_flow.py
index 3177a3aa09..67ed756630 100644
--- a/tests/trainer/loops/test_evaluation_loop_flow.py
+++ b/tests/trainer/loops/test_evaluation_loop_flow.py
@@ -81,7 +81,11 @@ def test__eval_step__flow(tmpdir):

     # make sure the optimizer closure returns the correct things
     opt_closure_result = trainer.train_loop.training_step_and_backward(
-        batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens
+        batch,
+        batch_idx,
+        0,
+        trainer.optimizers[0],
+        hiddens=None,
     )
     assert opt_closure_result['loss'].item() == 171

@@ -150,7 +154,7 @@ def test__eval_step__eval_step_end__flow(tmpdir):

     # make sure the optimizer closure returns the correct things
     opt_closure_result = trainer.train_loop.training_step_and_backward(
-        batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens
+        batch, batch_idx, 0, trainer.optimizers[0], hiddens=None
     )
     assert opt_closure_result['loss'].item() == 171

diff --git a/tests/trainer/loops/test_training_loop_flow_scalar.py b/tests/trainer/loops/test_training_loop_flow_scalar.py
index f14f7d339d..2f503b62f5 100644
--- a/tests/trainer/loops/test_training_loop_flow_scalar.py
+++ b/tests/trainer/loops/test_training_loop_flow_scalar.py
@@ -165,7 +165,11 @@ def test__training_step__epoch_end__flow_scalar(tmpdir):

     # make sure the optimizer closure returns the correct things
     opt_closure_result = trainer.train_loop.training_step_and_backward(
-        batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens
+        batch,
+        batch_idx,
+        0,
+        trainer.optimizers[0],
+        hiddens=None,
     )
     assert opt_closure_result['loss'].item() == 171

@@ -241,7 +245,7 @@ def test__training_step__step_end__epoch_end__flow_scalar(tmpdir):

     # make sure the optimizer closure returns the correct things
     opt_closure_result = trainer.train_loop.training_step_and_backward(
-        batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens
+        batch, batch_idx, 0, trainer.optimizers[0], hiddens=None
     )
     assert opt_closure_result['loss'].item() == 171
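
For reference, a minimal, illustrative sketch (not part of the patch) of how code outside the Trainer reads these values after the refactor: `split_idx` now lives on `trainer.train_loop` (as the updated `get_progress_bar_dict` hunk shows), and `hiddens` is kept privately as `TrainLoop._hiddens`, so callers of `training_step_and_backward` pass `hiddens` explicitly (as the updated tests do). The `RandomDataset`, `BoringModel`, and `SplitIdxCallback` names below are hypothetical helpers for the example only, not identifiers from this patch.

import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl


class RandomDataset(Dataset):
    """Tiny synthetic dataset so the sketch runs end to end."""

    def __init__(self, size: int = 32, length: int = 16):
        self.data = torch.randn(length, size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class SplitIdxCallback(pl.Callback):
    """Hypothetical callback: read the TBPTT split index from the train loop."""

    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        # before #7507: trainer.split_idx
        # after  #7507: trainer.train_loop.split_idx (None unless the batch was split for TBPTT)
        print("split_idx:", trainer.train_loop.split_idx)


class BoringModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        # a scalar loss is enough for this sketch
        return self(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    trainer = pl.Trainer(
        max_epochs=1,
        limit_train_batches=2,
        logger=False,
        callbacks=[SplitIdxCallback()],
    )
    trainer.fit(BoringModel(), DataLoader(RandomDataset(), batch_size=4))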