remove trainer hidden state | sanity refactor [2 / n] (#7507)
parent d0081778f8
commit 6e6e29af49
@@ -36,6 +36,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Refactored Loops
     * Moved attributes `global_step`, `current_epoch`, `max/min_steps`, `max/min_epochs`, `batch_idx`, and `total_batch_idx` to TrainLoop ([#7437](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
     * Refactored result handling in training loop ([#7506](https://github.com/PyTorchLightning/pytorch-lightning/pull/7506))
+    * Moved attributes `hiddens` and `split_idx` to TrainLoop ([#7507](https://github.com/PyTorchLightning/pytorch-lightning/pull/7507))

 - `DataModule`s now avoid duplicate `{setup,teardown,prepare_data}` calls for the same stage ([#7238](https://github.com/PyTorchLightning/pytorch-lightning/pull/7238))
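For code that used to read these attributes off the Trainer, the practical effect of this entry is a change of owner, not of behaviour. A minimal sketch, assuming a PyTorch Lightning build that already contains this refactor (`trainer.train_loop` and the attribute names below are taken from the hunks in this commit):

    from pytorch_lightning import Trainer

    trainer = Trainer(max_epochs=1)

    # split_idx is still readable, but it now lives on the train loop; it stays
    # None until a batch is split during truncated backprop through time.
    assert trainer.train_loop.split_idx is None

    # the hidden state is no longer mirrored on the Trainer; the loop keeps it
    # in the private attribute _hiddens (also None outside a TBPTT split).
    assert trainer.train_loop._hiddens is None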
@@ -1635,7 +1635,7 @@ class LightningModule(
         module_tbptt_enabled = self.truncated_bptt_steps > 0
         trainer_tbptt_enabled = self.trainer.truncated_bptt_steps is not None and self.trainer.truncated_bptt_steps > 0
         if module_tbptt_enabled or trainer_tbptt_enabled:
-            tqdm_dict["split_idx"] = self.trainer.split_idx
+            tqdm_dict["split_idx"] = self.trainer.train_loop.split_idx

         if self.trainer.logger is not None and self.trainer.logger.version is not None:
             version = self.trainer.logger.version
@@ -24,8 +24,9 @@ from pytorch_lightning.utilities.model_helpers import is_overridden

 class DataConnector(object):

-    def __init__(self, trainer):
+    def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"):
         self.trainer = trainer
+        self.multiple_trainloader_mode = multiple_trainloader_mode

     def on_trainer_init(
         self, check_val_every_n_epoch: int, reload_dataloaders_every_epoch: bool, prepare_data_per_node: bool
@@ -126,7 +126,6 @@ class LoggerConnector:
         self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
         self.trainer.log_every_n_steps = log_every_n_steps
         self.trainer.move_metrics_to_cpu = move_metrics_to_cpu
-        self.trainer.split_idx = None

     @property
     def should_flush_logs(self):
@@ -259,7 +259,7 @@ class TrainerDataLoadingMixin(ABC):
         apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn)

         # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
-        self.train_dataloader = CombinedLoader(self.train_dataloader, self._multiple_trainloader_mode)
+        self.train_dataloader = CombinedLoader(self.train_dataloader, self.data_connector.multiple_trainloader_mode)

         self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')
@@ -313,7 +313,7 @@ class Trainer(
         # init connectors
         self.dev_debugger = InternalDebugger(self)
         self.config_validator = ConfigValidator(self)
-        self.data_connector = DataConnector(self)
+        self.data_connector = DataConnector(self, multiple_trainloader_mode)
         self.optimizer_connector = OptimizerConnector(self)

         self.accelerator_connector = AcceleratorConnector(
@@ -329,9 +329,7 @@ class Trainer(
         self.checkpoint_connector = CheckpointConnector(self)
         self.slurm_connector = SLURMConnector(self)
         self.tuner = Tuner(self)
-        self.train_loop = TrainLoop(
-            self, multiple_trainloader_mode, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps
-        )
+        self.train_loop = TrainLoop(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps)
         self.evaluation_loop = EvaluationLoop(self)
         self.predict_loop = PredictLoop(self)
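Taken together with the DataConnector and dataloader hunks above, the multiple_trainloader_mode flag now flows Trainer -> DataConnector -> CombinedLoader rather than through the TrainLoop. A small sketch of the resulting surface, assuming a build with this change applied (the Trainer argument itself is unchanged):

    from pytorch_lightning import Trainer

    # same public argument as before; only the internal owner moves
    trainer = Trainer(multiple_trainloader_mode="max_size_cycle")

    # the DataConnector now stores the flag, and the dataloader setup code reads
    # it back from there when wrapping multiple train loaders in a CombinedLoader
    assert trainer.data_connector.multiple_trainloader_mode == "max_size_cycle"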
@@ -40,7 +40,6 @@ class TrainLoop:
     def __init__(
         self,
         trainer,
-        multiple_trainloader_mode: str,
         max_epochs: Optional[int],
         min_epochs: Optional[int],
         max_steps: Optional[int],
@@ -52,17 +51,21 @@ class TrainLoop:
         self.warning_cache = WarningCache()
         self._teardown_already_run = False
         self.running_loss = TensorRunningAccum(window_length=20)
-        self._multiple_trainloader_mode = multiple_trainloader_mode
         self._skip_backward = False
-        self.trainer._multiple_trainloader_mode = multiple_trainloader_mode
         self._optimizer_freq_cumsum = None
+        self._hiddens = None

         self.global_step = 0
         self.current_epoch = 0
         self.trainer.should_stop = False

         # the total batch index across all epochs
         self.total_batch_idx = 0
         # the current batch index in the loop that runs over the dataloader(s)
         self.batch_idx = 0
+        # the current split index when the batch gets split into chunks in truncated backprop through time
+        self.split_idx = None

         self.trainer.num_training_batches = 0
         self.trainer.train_dataloader = None
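For readers skimming the hunk above, here is a stripped-down, hypothetical mirror of the state that now lives on the loop rather than on the Trainer; MiniTrainLoop is illustrative only, with the attribute names and defaults taken from the lines above:

    class MiniTrainLoop:
        # illustrative stand-in, not the real class
        def __init__(self) -> None:
            self._hiddens = None      # hidden state carried between TBPTT splits (was trainer.hiddens)
            self.split_idx = None     # current split index within a batch (was trainer.split_idx)
            self.batch_idx = 0        # current batch index over the dataloader(s)
            self.total_batch_idx = 0  # total batch index across all epochs

    loop = MiniTrainLoop()
    assert loop.split_idx is None and loop._hiddens is None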
@@ -337,7 +340,7 @@ class TrainLoop:

         # map to results under the hood
         result.minimize = loss
-        self.trainer.hiddens = hiddens
+        self._hiddens = hiddens

         # track batch for manual reduction with result
         result.track_batch_size(len(split_batch))
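The hiddens value stashed here originates in the user's training_step. A sketch of standard truncated-BPTT usage at this point in the API, not code from this commit (the module name, tensor shapes, and dummy loss are illustrative):

    import torch
    import torch.nn.functional as F
    from pytorch_lightning import LightningModule

    class TBPTTSketch(LightningModule):
        def __init__(self):
            super().__init__()
            self.rnn = torch.nn.LSTM(input_size=10, hidden_size=20, batch_first=True)
            # enabling TBPTT on the module makes the loop split each batch and
            # pass the previous split's hiddens into training_step
            self.truncated_bptt_steps = 5

        def training_step(self, batch, batch_idx, hiddens):
            x, _ = batch  # x: (batch, time, 10); hiddens is None for the first split
            out, hiddens = self.rnn(x, hiddens)
            loss = F.mse_loss(out, torch.zeros_like(out))  # dummy objective for the sketch
            # the returned hiddens is what the loop now keeps on self._hiddens
            # and feeds back into the next split's training_step
            return {"loss": loss, "hiddens": hiddens}

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)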
@@ -479,7 +482,6 @@ class TrainLoop:

         for batch_idx, (batch, is_last_batch) in train_dataloader:
             self.batch_idx = batch_idx
             self.trainer.is_last_batch = is_last_batch

             # ------------------------------------
             # TRAINING_STEP + TRAINING_STEP_END
@@ -656,7 +658,7 @@ class TrainLoop:
         grad_norm_dict = {}

         # bookkeeping
-        self.trainer.hiddens = None
+        self._hiddens = None

         optimizers = self.prepare_optimizers()
@@ -685,6 +687,7 @@ class TrainLoop:
         splits = self._tbptt_split_batch(batch)

         for split_idx, split_batch in enumerate(splits):
+            self.split_idx = split_idx

             # create an iterable for optimizers and loop over them
             for opt_idx, optimizer in optimizers:
@@ -703,9 +706,7 @@ class TrainLoop:
                     # automatic_optimization=True: perform dpp sync only when performing optimizer_step
                     # automatic_optimization=False: don't block synchronization here
                     with self.block_ddp_sync_behaviour():
-                        result = self.training_step_and_backward(
-                            split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
-                        )
+                        self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self._hiddens)

                 # ------------------------------
                 # BACKWARD PASS
@@ -717,7 +718,7 @@ class TrainLoop:
                         def train_step_and_backward_closure():
                             nonlocal result
                             result = self.training_step_and_backward(
-                                split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
+                                split_batch, batch_idx, opt_idx, optimizer, self._hiddens
                             )
                             return None if result is None else result.loss
@@ -725,7 +726,7 @@ class TrainLoop:
                         self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)

                     else:
-                        result = self.training_step(split_batch, batch_idx, opt_idx, self.trainer.hiddens)
+                        result = self.training_step(split_batch, batch_idx, opt_idx, self._hiddens)

                         if not result:
                             # user decided to skip optimization
@@ -968,9 +969,6 @@ class TrainLoop:
         return optimizers

     def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer):
-        # set split_idx to trainer for tracking
-        self.trainer.split_idx = split_idx
-
         # make sure only the gradients of the current optimizer's parameters are calculated
         # in the training step to prevent dangling gradients in multiple-optimizer setup.
         if self.trainer.lightning_module.automatic_optimization and len(self.trainer.optimizers) > 1:
@@ -81,7 +81,11 @@ def test__eval_step__flow(tmpdir):

     # make sure the optimizer closure returns the correct things
     opt_closure_result = trainer.train_loop.training_step_and_backward(
-        batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens
+        batch,
+        batch_idx,
+        0,
+        trainer.optimizers[0],
+        hiddens=None,
     )
     assert opt_closure_result['loss'].item() == 171
@@ -150,7 +154,7 @@ def test__eval_step__eval_step_end__flow(tmpdir):

     # make sure the optimizer closure returns the correct things
     opt_closure_result = trainer.train_loop.training_step_and_backward(
-        batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens
+        batch, batch_idx, 0, trainer.optimizers[0], hiddens=None
     )
     assert opt_closure_result['loss'].item() == 171
@@ -165,7 +165,11 @@ def test__training_step__epoch_end__flow_scalar(tmpdir):

     # make sure the optimizer closure returns the correct things
     opt_closure_result = trainer.train_loop.training_step_and_backward(
-        batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens
+        batch,
+        batch_idx,
+        0,
+        trainer.optimizers[0],
+        hiddens=None,
     )
     assert opt_closure_result['loss'].item() == 171
@@ -241,7 +245,7 @@ def test__training_step__step_end__epoch_end__flow_scalar(tmpdir):

     # make sure the optimizer closure returns the correct things
     opt_closure_result = trainer.train_loop.training_step_and_backward(
-        batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens
+        batch, batch_idx, 0, trainer.optimizers[0], hiddens=None
     )
     assert opt_closure_result['loss'].item() == 171