remove trainer hidden state | sanity refactor [2 / n] (#7507)

Adrian Wälchli 2021-05-17 09:57:15 +02:00 committed by GitHub
parent d0081778f8
commit 6e6e29af49
9 changed files with 31 additions and 26 deletions

View File

@@ -36,6 +36,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Refactored Loops
* Moved attributes `global_step`, `current_epoch`, `max/min_steps`, `max/min_epochs`, `batch_idx`, and `total_batch_idx` to TrainLoop ([#7437](https://github.com/PyTorchLightning/pytorch-lightning/pull/7437))
* Refactored result handling in training loop ([#7506](https://github.com/PyTorchLightning/pytorch-lightning/pull/7506))
* Moved attributes `hiddens` and `split_idx` to TrainLoop ([#7507](https://github.com/PyTorchLightning/pytorch-lightning/pull/7507))
- `DataModule`s now avoid duplicate `{setup,teardown,prepare_data}` calls for the same stage ([#7238](https://github.com/PyTorchLightning/pytorch-lightning/pull/7238))
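
For reference, the `hiddens`/`split_idx` entry above means that code which used to read these values from the `Trainer` now finds them on the train loop object. A minimal sketch from the callback side, assuming the PL ~1.3 callback API; the `SplitIdxLogger` class and its printout are illustrative only, not part of this PR:

```python
import pytorch_lightning as pl


class SplitIdxLogger(pl.Callback):
    """Toy callback: read the TBPTT split index from its new home on the train loop."""

    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        # before this commit: trainer.split_idx (hidden trainer state)
        # after this commit:  the train loop owns the state
        print("current TBPTT split:", trainer.train_loop.split_idx)
```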

View File

@@ -1635,7 +1635,7 @@ class LightningModule(
module_tbptt_enabled = self.truncated_bptt_steps > 0
trainer_tbptt_enabled = self.trainer.truncated_bptt_steps is not None and self.trainer.truncated_bptt_steps > 0
if module_tbptt_enabled or trainer_tbptt_enabled:
tqdm_dict["split_idx"] = self.trainer.split_idx
tqdm_dict["split_idx"] = self.trainer.train_loop.split_idx
if self.trainer.logger is not None and self.trainer.logger.version is not None:
version = self.trainer.logger.version
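
The hunk above is part of `get_progress_bar_dict`, which adds a `split_idx` entry whenever truncated backprop through time is enabled and now fetches it from `trainer.train_loop`. A short sketch of how a user module can customise that dict, assuming the PL ~1.3 hook; `MyModule` and the `v_num` tweak are illustrative (the latter mirrors the standard docs example):

```python
import pytorch_lightning as pl


class MyModule(pl.LightningModule):
    # Sketch: extend/trim the progress-bar dict built in the hunk above. With TBPTT
    # enabled, the base implementation already inserts "split_idx" from the train loop.
    def get_progress_bar_dict(self):
        items = super().get_progress_bar_dict()
        items.pop("v_num", None)  # e.g. hide the logger version entry
        return items
```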

View File

@@ -24,8 +24,9 @@ from pytorch_lightning.utilities.model_helpers import is_overridden
class DataConnector(object):
def __init__(self, trainer):
def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"):
self.trainer = trainer
self.multiple_trainloader_mode = multiple_trainloader_mode
def on_trainer_init(
self, check_val_every_n_epoch: int, reload_dataloaders_every_epoch: bool, prepare_data_per_node: bool
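
With this change the `DataConnector` owns `multiple_trainloader_mode` instead of the value being stashed on the trainer. The public `Trainer` argument is unchanged; only the internal owner moves. A small sketch of where the value ends up after construction (the `data_connector` attribute access is illustrative internal state, not documented public API):

```python
from pytorch_lightning import Trainer

trainer = Trainer(multiple_trainloader_mode="min_size", max_epochs=1)
# after this commit the mode is read from the data connector rather than
# from a hidden trainer attribute
assert trainer.data_connector.multiple_trainloader_mode == "min_size"
```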

View File

@@ -126,7 +126,6 @@ class LoggerConnector:
self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
self.trainer.log_every_n_steps = log_every_n_steps
self.trainer.move_metrics_to_cpu = move_metrics_to_cpu
self.trainer.split_idx = None
@property
def should_flush_logs(self):

View File

@@ -259,7 +259,7 @@ class TrainerDataLoadingMixin(ABC):
apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn)
# wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
self.train_dataloader = CombinedLoader(self.train_dataloader, self._multiple_trainloader_mode)
self.train_dataloader = CombinedLoader(self.train_dataloader, self.data_connector.multiple_trainloader_mode)
self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')
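
The hunk above wraps the (possibly multiple) train dataloaders in a `CombinedLoader`, now using the mode stored on the data connector. A standalone sketch of what the two supported modes mean, assuming the PL ~1.3 import location `pytorch_lightning.trainer.supporters`:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning.trainer.supporters import CombinedLoader  # PL ~1.3 location

loader_a = DataLoader(TensorDataset(torch.randn(10, 2)), batch_size=2)  # 5 batches
loader_b = DataLoader(TensorDataset(torch.randn(4, 2)), batch_size=2)   # 2 batches

# "max_size_cycle" (the default) cycles the shorter loader to the length of the longest;
# "min_size" stops with the shortest loader instead.
combined = CombinedLoader({"a": loader_a, "b": loader_b}, "max_size_cycle")
print(len(combined))  # 5 in this mode, 2 with "min_size"
```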

View File

@@ -313,7 +313,7 @@ class Trainer(
# init connectors
self.dev_debugger = InternalDebugger(self)
self.config_validator = ConfigValidator(self)
self.data_connector = DataConnector(self)
self.data_connector = DataConnector(self, multiple_trainloader_mode)
self.optimizer_connector = OptimizerConnector(self)
self.accelerator_connector = AcceleratorConnector(
@@ -329,9 +329,7 @@ class Trainer(
self.checkpoint_connector = CheckpointConnector(self)
self.slurm_connector = SLURMConnector(self)
self.tuner = Tuner(self)
self.train_loop = TrainLoop(
self, multiple_trainloader_mode, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps
)
self.train_loop = TrainLoop(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps)
self.evaluation_loop = EvaluationLoop(self)
self.predict_loop = PredictLoop(self)
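
The wiring above forwards `multiple_trainloader_mode` to the `DataConnector` and keeps `TrainLoop` limited to the loop arguments. The public constructor signature is untouched, so the usual call still works; the attribute accesses below are illustrative internal state (per this PR and the #7437 changelog entry), not documented public API:

```python
from pytorch_lightning import Trainer

trainer = Trainer(max_epochs=3, max_steps=100, num_sanity_val_steps=0)

# loop limits and counters live on the TrainLoop constructed above ...
print(trainer.train_loop.max_epochs, trainer.train_loop.global_step)
# ... while the dataloader mode lives on the DataConnector
print(trainer.data_connector.multiple_trainloader_mode)
```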

View File

@@ -40,7 +40,6 @@ class TrainLoop:
def __init__(
self,
trainer,
multiple_trainloader_mode: str,
max_epochs: Optional[int],
min_epochs: Optional[int],
max_steps: Optional[int],
@@ -52,17 +51,21 @@ class TrainLoop:
self.warning_cache = WarningCache()
self._teardown_already_run = False
self.running_loss = TensorRunningAccum(window_length=20)
self._multiple_trainloader_mode = multiple_trainloader_mode
self._skip_backward = False
self.trainer._multiple_trainloader_mode = multiple_trainloader_mode
self._optimizer_freq_cumsum = None
self._hiddens = None
self.global_step = 0
self.current_epoch = 0
self.trainer.should_stop = False
# the total batch index across all epochs
self.total_batch_idx = 0
# the current batch index in the loop that runs over the dataloader(s)
self.batch_idx = 0
# the current split index when the batch gets split into chunks in truncated backprop through time
self.split_idx = None
self.trainer.num_training_batches = 0
self.trainer.train_dataloader = None
@@ -337,7 +340,7 @@ class TrainLoop:
# map to results under the hood
result.minimize = loss
self.trainer.hiddens = hiddens
self._hiddens = hiddens
# track batch for manual reduction with result
result.track_batch_size(len(split_batch))
@@ -479,7 +482,6 @@ class TrainLoop:
for batch_idx, (batch, is_last_batch) in train_dataloader:
self.batch_idx = batch_idx
self.trainer.is_last_batch = is_last_batch
# ------------------------------------
# TRAINING_STEP + TRAINING_STEP_END
@@ -656,7 +658,7 @@ class TrainLoop:
grad_norm_dict = {}
# bookkeeping
self.trainer.hiddens = None
self._hiddens = None
optimizers = self.prepare_optimizers()
@@ -685,6 +687,7 @@ class TrainLoop:
splits = self._tbptt_split_batch(batch)
for split_idx, split_batch in enumerate(splits):
self.split_idx = split_idx
# create an iterable for optimizers and loop over them
for opt_idx, optimizer in optimizers:
@@ -703,9 +706,7 @@ class TrainLoop:
# automatic_optimization=True: perform ddp sync only when performing optimizer_step
# automatic_optimization=False: don't block synchronization here
with self.block_ddp_sync_behaviour():
result = self.training_step_and_backward(
split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
)
self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self._hiddens)
# ------------------------------
# BACKWARD PASS
@@ -717,7 +718,7 @@ class TrainLoop:
def train_step_and_backward_closure():
nonlocal result
result = self.training_step_and_backward(
split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
split_batch, batch_idx, opt_idx, optimizer, self._hiddens
)
return None if result is None else result.loss
@@ -725,7 +726,7 @@ class TrainLoop:
self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
else:
result = self.training_step(split_batch, batch_idx, opt_idx, self.trainer.hiddens)
result = self.training_step(split_batch, batch_idx, opt_idx, self._hiddens)
if not result:
# user decided to skip optimization
@@ -968,9 +969,6 @@ class TrainLoop:
return optimizers
def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer):
# set split_idx to trainer for tracking
self.trainer.split_idx = split_idx
# make sure only the gradients of the current optimizer's parameters are calculated
# in the training step to prevent dangling gradients in multiple-optimizer setup.
if self.trainer.lightning_module.automatic_optimization and len(self.trainer.optimizers) > 1:
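
Most of the changes in this file route the truncated-backprop state through `self._hiddens` and `self.split_idx` on the loop itself. A self-contained toy that exercises exactly this path, assuming the PL ~1.3 TBPTT API (`truncated_bptt_steps` module property, `hiddens` passed into and returned from `training_step`); `TBPTTToy` and the random data are illustrative only:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TBPTTToy(pl.LightningModule):
    """Toy model exercising the TBPTT path whose per-split state this commit moves onto the TrainLoop."""

    def __init__(self):
        super().__init__()
        self.rnn = nn.GRU(input_size=1, hidden_size=8, batch_first=True)
        self.head = nn.Linear(8, 1)
        self.truncated_bptt_steps = 2  # split each batch along time into chunks of 2 steps

    def training_step(self, batch, batch_idx, hiddens=None):
        x, y = batch  # (batch, time, feature) slices produced by tbptt_split_batch
        out, hiddens = self.rnn(x, hiddens)
        loss = nn.functional.mse_loss(self.head(out), y)
        # the returned (detached) hiddens are handed to the next split by the train loop
        return {"loss": loss, "hiddens": hiddens.detach()}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


# 16 sequences of 8 time steps -> each batch of 4 is split into 4 TBPTT chunks
data = TensorDataset(torch.randn(16, 8, 1), torch.randn(16, 8, 1))
trainer = pl.Trainer(max_epochs=1, logger=False, checkpoint_callback=False)
trainer.fit(TBPTTToy(), DataLoader(data, batch_size=4))
```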

View File

@@ -81,7 +81,11 @@ def test__eval_step__flow(tmpdir):
# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.train_loop.training_step_and_backward(
batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens
batch,
batch_idx,
0,
trainer.optimizers[0],
hiddens=None,
)
assert opt_closure_result['loss'].item() == 171
@@ -150,7 +154,7 @@ def test__eval_step__eval_step_end__flow(tmpdir):
# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.train_loop.training_step_and_backward(
batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens
batch, batch_idx, 0, trainer.optimizers[0], hiddens=None
)
assert opt_closure_result['loss'].item() == 171

View File

@@ -165,7 +165,11 @@ def test__training_step__epoch_end__flow_scalar(tmpdir):
# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.train_loop.training_step_and_backward(
batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens
batch,
batch_idx,
0,
trainer.optimizers[0],
hiddens=None,
)
assert opt_closure_result['loss'].item() == 171
@@ -241,7 +245,7 @@ def test__training_step__step_end__epoch_end__flow_scalar(tmpdir):
# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.train_loop.training_step_and_backward(
batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens
batch, batch_idx, 0, trainer.optimizers[0], hiddens=None
)
assert opt_closure_result['loss'].item() == 171