From 256059a1d071a50c44340d41523861dd86c3accd Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 9 Aug 2020 06:00:15 -0400 Subject: [PATCH] tracks all outputs including TBPTT and multiple optimizers (#2890) * pl 0.9 update * pl 0.9 update * pl 0.9 update * pl 0.9 update * pl 0.9 update * pl 0.9 update * pl 0.9 update * pl 0.9 update * pl 0.9 update * pl 0.9 update * pl 0.9 update * pl 0.9 update * pl 0.9 update * pl 0.9 update * pl 0.9 update * pl 0.9 update * pl 0.9 update * pl 0.9 update * pl 0.9 update * pl 0.9 update * pl 0.9 update --- pytorch_lightning/core/step_result.py | 210 +++++++++++++++--- pytorch_lightning/trainer/training_loop.py | 142 ++++++++++-- tests/base/model_train_steps.py | 1 + tests/models/test_cpu.py | 155 +++++++++++++ .../trainer/test_trainer_steps_dict_return.py | 10 + .../test_trainer_steps_result_return.py | 8 + .../test_trainer_steps_scalar_return.py | 8 + 7 files changed, 485 insertions(+), 49 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 84f9669cf4..5174f8aa44 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -92,6 +92,8 @@ class Result(Dict): on_step: bool = False, on_epoch: bool = True, reduce_fx: Callable = torch.mean, + tbptt_reduce_fx: Callable = torch.mean, + tbptt_pad_token: int = 0, enable_graph: bool = False, sync_ddp: bool = False, sync_ddp_op: Union[Any, str] = 'mean', @@ -113,15 +115,22 @@ class Result(Dict): if on_step and on_epoch: # set step version step_name = f'step_{name}' - self.__set_meta(step_name, value, prog_bar, logger, on_step=True, on_epoch=False, reduce_fx=reduce_fx) + self.__set_meta(step_name, value, prog_bar, logger, + on_step=True, on_epoch=False, + reduce_fx=reduce_fx, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token) self.__setitem__(step_name, value) # set epoch version epoch_name = f'epoch_{name}' - self.__set_meta(epoch_name, value, prog_bar, logger, on_step=False, on_epoch=True, reduce_fx=reduce_fx) + self.__set_meta(epoch_name, value, prog_bar, logger, on_step=False, on_epoch=True, + reduce_fx=reduce_fx, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token) self.__setitem__(epoch_name, value) else: - self.__set_meta(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx) + self.__set_meta(name, value, + prog_bar, logger, + on_step, on_epoch, + reduce_fx, + tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token) # set the value self.__setitem__(name, value) @@ -135,6 +144,8 @@ class Result(Dict): on_step: bool, on_epoch: bool, reduce_fx: Callable, + tbptt_pad_token: int, + tbptt_reduce_fx: Callable ): # set the meta for the item meta_value = value @@ -144,7 +155,9 @@ class Result(Dict): on_step=on_step, on_epoch=on_epoch, reduce_fx=reduce_fx, - value=meta_value + value=meta_value, + tbptt_reduce_fx=tbptt_reduce_fx, + tbptt_pad_token=tbptt_pad_token ) self['meta'][name] = meta @@ -253,6 +266,39 @@ class Result(Dict): result['meta'] = meta return result + @classmethod + def padded_gather(cls, outputs): + meta = outputs[0].get('meta') + result = cls() + result = recursive_gather(outputs, result) + + # find the padding used for other values + default_padding_idx = 0 + for name, value in result.items(): + if isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor): + if name not in {'checkpoint_on', 'early_stop_on', 'minimize'}: + default_padding_idx = meta[name]['tbptt_pad_token'] + break + + # pad across each key individually + for name, 
value in result.items(): + is_reserved = name in {'checkpoint_on', 'early_stop_on', 'minimize'} + if isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor): + + if is_reserved: + padding_key = default_padding_idx + else: + padding_key = meta[name]['tbptt_pad_token'] + padded = torch.nn.utils.rnn.pad_sequence(value, batch_first=True, padding_value=padding_key) + result[name] = padded + + # also update the result + if meta and not is_reserved: + meta[name]['value'] = padded + if meta: + result['meta'] = meta + return result + @classmethod def reduce_on_epoch_end(cls, outputs): meta = outputs[0]['meta'] @@ -271,10 +317,36 @@ class Result(Dict): result['meta'] = meta return result + @classmethod + def reduce_across_time(cls, time_outputs): + # auto-reduce across time for tbptt + meta = time_outputs[0]['meta'] + result = cls() + result = recursive_gather(time_outputs, result) + recursive_stack(result) + + for k, value in result.items(): + if k == 'meta': + continue + + # pick the reduce fx + if k in ['checkpoint_on', 'early_stop_on', 'minimize']: + tbptt_reduce_fx = torch.mean + else: + tbptt_reduce_fx = meta[k]['tbptt_reduce_fx'] + result[k] = tbptt_reduce_fx(value) + + result['meta'] = meta + return result + @property def should_reduce_on_epoch_end(self) -> bool: return self['meta']['_internal']['_reduce_on_epoch'] + def drop_hiddens(self): + if 'hiddens' in self: + del self['hiddens'] + def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] = None) -> Optional[MutableMapping]: for out in outputs: @@ -303,6 +375,16 @@ def recursive_stack(result: MutableMapping): result[k] = v +def recursive_padded_stack(result: MutableMapping): + for k, v in result.items(): + if isinstance(v, dict): + recursive_stack(v) + + if isinstance(v, list) and len(v) > 0 and isinstance(v[0], torch.Tensor): + v = torch.stack(v) + result[k] = v + + class TrainResult(Result): def __init__( @@ -348,6 +430,8 @@ class TrainResult(Result): on_step: bool = True, on_epoch: bool = False, reduce_fx: Callable = torch.mean, + tbptt_reduce_fx: Callable = torch.mean, + tbptt_pad_token: int = 0, enable_graph: bool = False, sync_ddp: bool = False, sync_ddp_op: Union[Any, str] = 'mean', @@ -381,10 +465,26 @@ class TrainResult(Result): on_step: if True logs the output of validation_step or test_step on_epoch: if True, logs the output of the training loop aggregated reduce_fx: Torch.mean by default + tbptt_reduce_fx: function to reduce on truncated back prop + tbptt_pad_token: token to use for padding enable_graph: if True, will not auto detach the graph + sync_ddp: if True, reduces the metric across GPUs/TPUs + sync_ddp_op: the op to sync across + sync_ddp_group: the ddp group """ - super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph, - sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op) + super().log(name=name, + value=value, + prog_bar=prog_bar, + logger=logger, + on_step=on_step, + on_epoch=on_epoch, + reduce_fx=reduce_fx, + enable_graph=enable_graph, + sync_ddp=sync_ddp, + sync_ddp_group=sync_ddp_group, + sync_ddp_op=sync_ddp_op, + tbptt_pad_token=tbptt_pad_token, + tbptt_reduce_fx=tbptt_reduce_fx) def log_dict( self, @@ -394,6 +494,8 @@ class TrainResult(Result): on_step: bool = False, on_epoch: bool = True, reduce_fx: Callable = torch.mean, + tbptt_reduce_fx: Callable = torch.mean, + tbptt_pad_token: int = 0, enable_graph: bool = False, sync_ddp: bool = False, sync_ddp_op: Union[Any, str] = 'mean', @@ -408,17 +510,33 @@ class 
TrainResult(Result): result.log_dict(values) Args: - dictionary: - prog_bar: - logger: - on_step: - on_epoch: - reduce_fx: - enable_graph: + dictionary: key value pairs (str, tensors) + prog_bar: if True logs to the progress base + logger: if True logs to the logger + on_step: if True logs the output of validation_step or test_step + on_epoch: if True, logs the output of the training loop aggregated + reduce_fx: Torch.mean by default + tbptt_reduce_fx: function to reduce on truncated back prop + tbptt_pad_token: token to use for padding + enable_graph: if True, will not auto detach the graph + sync_ddp: if True, reduces the metric across GPUs/TPUs + sync_ddp_op: the op to sync across + sync_ddp_group: the ddp group: """ for k, v in dictionary.items(): - self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph, - sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op) + self.log(name=k, + value=v, + prog_bar=prog_bar, + logger=logger, + on_step=on_step, + on_epoch=on_epoch, + reduce_fx=reduce_fx, + enable_graph=enable_graph, + sync_ddp=sync_ddp, + sync_ddp_group=sync_ddp_group, + sync_ddp_op=sync_ddp_op, + tbptt_pad_token=tbptt_pad_token, + tbptt_reduce_fx=tbptt_reduce_fx) class EvalResult(Result): @@ -464,6 +582,8 @@ class EvalResult(Result): on_step: bool = False, on_epoch: bool = True, reduce_fx: Callable = torch.mean, + tbptt_reduce_fx: Callable = torch.mean, + tbptt_pad_token: int = 0, enable_graph: bool = False, sync_ddp: bool = False, sync_ddp_op: Union[Any, str] = 'mean', @@ -494,12 +614,28 @@ class EvalResult(Result): prog_bar: if True logs to the progress base logger: if True logs to the logger on_step: if True logs the output of validation_step or test_step - on_epoch: if True, logs the output of the validation loop or test loop aggregated + on_epoch: if True, logs the output of the training loop aggregated reduce_fx: Torch.mean by default - enable_graph: if True, will not auto detach the graph : + tbptt_reduce_fx: function to reduce on truncated back prop + tbptt_pad_token: token to use for padding + enable_graph: if True, will not auto detach the graph + sync_ddp: if True, reduces the metric across GPUs/TPUs + sync_ddp_op: the op to sync across + sync_ddp_group: the ddp group """ - super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph, - sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op) + super().log(name=name, + value=value, + prog_bar=prog_bar, + logger=logger, + on_step=on_step, + on_epoch=on_epoch, + reduce_fx=reduce_fx, + enable_graph=enable_graph, + sync_ddp=sync_ddp, + sync_ddp_group=sync_ddp_group, + sync_ddp_op=sync_ddp_op, + tbptt_pad_token=tbptt_pad_token, + tbptt_reduce_fx=tbptt_reduce_fx) def log_dict( self, @@ -509,6 +645,8 @@ class EvalResult(Result): on_step: bool = False, on_epoch: bool = True, reduce_fx: Callable = torch.mean, + tbptt_reduce_fx: Callable = torch.mean, + tbptt_pad_token: int = 0, enable_graph: bool = False, sync_ddp: bool = False, sync_ddp_op: Union[Any, str] = 'mean', @@ -523,17 +661,33 @@ class EvalResult(Result): result.log_dict(values) Args: - dictionary: - prog_bar: - logger: - on_step: - on_epoch: - reduce_fx: - enable_graph: + dictionary: key value pairs (str, tensors) + prog_bar: if True logs to the progress base + logger: if True logs to the logger + on_step: if True logs the output of validation_step or test_step + on_epoch: if True, logs the output of the training loop aggregated + reduce_fx: Torch.mean by default + tbptt_reduce_fx: 
function to reduce on truncated back prop + tbptt_pad_token: token to use for padding + enable_graph: if True, will not auto detach the graph + sync_ddp: if True, reduces the metric across GPUs/TPUs + sync_ddp_op: the op to sync across + sync_ddp_group: the ddp group """ for k, v in dictionary.items(): - self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph, - sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op) + self.log(name=k, + value=v, + prog_bar=prog_bar, + logger=logger, + on_step=on_step, + on_epoch=on_epoch, + reduce_fx=reduce_fx, + enable_graph=enable_graph, + sync_ddp=sync_ddp, + sync_ddp_group=sync_ddp_group, + sync_ddp_op=sync_ddp_op, + tbptt_pad_token=tbptt_pad_token, + tbptt_reduce_fx=tbptt_reduce_fx) def get_callback_metrics(self) -> dict: result = { diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 72178b8e8e..ea9f915a8d 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -462,7 +462,8 @@ class TrainerTrainLoopMixin(ABC): train_dataloader = self.prepare_train_loop_dataloader(self.train_dataloader) # bookkeeping - epoch_output = [] + num_optimizers = len(self._get_optimizers_iterable()) + epoch_output = [[] for _ in range(num_optimizers)] should_check_val = False # structured result accumulators for callbacks @@ -487,16 +488,18 @@ class TrainerTrainLoopMixin(ABC): # only track outputs when user implements training_epoch_end # otherwise we will build up unnecessary memory - step_out = batch_output.training_step_output_for_epoch_end - should_auto_reduce_train_result = isinstance(step_out, Result) and step_out.should_reduce_on_epoch_end - if isinstance(step_out, dict) and 'early_stop_on' in step_out: - early_stopping_accumulator.accumulate(step_out['early_stop_on']) + epoch_end_outputs = self.process_train_step_outputs( + batch_output.training_step_output_for_epoch_end, + early_stopping_accumulator, + checkpoint_accumulator + ) - if isinstance(step_out, dict) and 'checkpoint_on' in step_out: - checkpoint_accumulator.accumulate(step_out['checkpoint_on']) - - if self.is_overridden('training_epoch_end', model=self.get_model()) or should_auto_reduce_train_result: - epoch_output.append(batch_output.training_step_output_for_epoch_end) + # track the outputs to reduce at the end of the epoch + for opt_idx, opt_outputs in enumerate(epoch_end_outputs): + # with 1 step (no tbptt) don't use a sequence at epoch end + if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result): + opt_outputs = opt_outputs[0] + epoch_output[opt_idx].append(opt_outputs) # update LR schedulers self.update_train_loop_lr_schedulers() @@ -538,7 +541,7 @@ class TrainerTrainLoopMixin(ABC): self.sync_horovod() # process epoch outputs - self.run_training_epoch_end(epoch_output, checkpoint_accumulator, early_stopping_accumulator) + self.run_training_epoch_end(epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers) # checkpoint callback self.check_checkpoint_callback(should_check_val) @@ -546,6 +549,35 @@ class TrainerTrainLoopMixin(ABC): # epoch end hook self.run_on_epoch_end_hook(model) + def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accumulator, checkpoint_accumulator): + """ + Figure out what needs to be tracked/logged at the end of the epoch + """ + + # the training step outputs a list per optimizer. 
The list contains the outputs at each time step + # when no TBPTT is used, then the list has 1 item per batch + # when TBPTT IS used, then the list has n items (1 per time step) + epoch_end_outputs = [] + for optimizer_idx_outputs in all_train_step_outputs: + # extract one representative sample from each time step (1 if no tbptt) and 0th optimizer + sample_output = optimizer_idx_outputs[-1] + + # pull out callback info if available (ie: Results object) + if isinstance(sample_output, dict) and 'early_stop_on' in sample_output: + early_stopping_accumulator.accumulate(sample_output['early_stop_on']) + + if isinstance(sample_output, dict) and 'checkpoint_on' in sample_output: + checkpoint_accumulator.accumulate(sample_output['checkpoint_on']) + + # decide if we need to reduce at the end of the epoch automatically + auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end + + # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end + if self.is_overridden('training_epoch_end', model=self.get_model()) or auto_reduce_tng_result: + epoch_end_outputs.append(optimizer_idx_outputs) + + return epoch_end_outputs + def check_checkpoint_callback(self, should_check_val): # when no val loop is present or fast-dev-run still need to call checkpoints # TODO bake this logic into the checkpoint callback @@ -575,9 +607,12 @@ class TrainerTrainLoopMixin(ABC): if self.is_function_implemented('on_train_epoch_end'): model.on_train_epoch_end() - def run_training_epoch_end(self, epoch_output, checkpoint_accumulator, early_stopping_accumulator): + def run_training_epoch_end(self, epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers): + # epoch output is a list. 
Each item in that list has all the outputs per optimizer + # epoch_output[optimizer_idx][training_step_idx][tbptt_index] + # remember that not using truncated backprop is equivalent with truncated back prop of len(1) + model = self.get_model() - is_result_obj = len(epoch_output) > 0 and isinstance(epoch_output[0], Result) epoch_log_metrics = {} epoch_callback_metrics = {} @@ -592,17 +627,33 @@ class TrainerTrainLoopMixin(ABC): if early_stopping_accumulator.num_values > 0: epoch_callback_metrics['early_stop_on'] = early_stopping_accumulator.mean() + # ------------------------ + # determine if using a result obj + # ------------------------ + # [optimizer_idx][training_step_idx][tbptt_index] + opt_idx_outputs = epoch_output[0] + + try: + sample_obj = opt_idx_outputs[0][0] if isinstance(opt_idx_outputs[0], list) else opt_idx_outputs[0] + is_result_obj = len(epoch_output) > 0 and isinstance(sample_obj, Result) + except IndexError as e: + is_result_obj = False + # -------------------------- # EPOCH END STEP IF DEFINED # -------------------------- if self.is_overridden('training_epoch_end', model=model): self.global_step += 1 - # remove the protected keys so the user doesn't have to deal with them if is_result_obj: - epoch_output = epoch_output[0].__class__.gather(epoch_output) + # with result object gather across time and training steps so each opt idx has a single result obj + epoch_output = self.__gather_result_across_time_and_optimizers(epoch_output) + + if num_optimizers == 1: + epoch_output = epoch_output[0] # run training_epoch_end + # a list with a result per optimizer index epoch_output = model.training_epoch_end(epoch_output) if isinstance(epoch_output, Result): @@ -618,10 +669,7 @@ class TrainerTrainLoopMixin(ABC): # Structured Result (auto epoch end) # -------------------------- elif is_result_obj: - epoch_output = epoch_output[0].__class__.reduce_on_epoch_end(epoch_output) - epoch_output.minimize = epoch_output.minimize.mean() - epoch_log_metrics = epoch_output.epoch_log_metrics - epoch_progress_bar_metrics = epoch_output.epoch_pbar_metrics + epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output) # -------------------------- # track results @@ -637,6 +685,49 @@ class TrainerTrainLoopMixin(ABC): if len(epoch_progress_bar_metrics) > 0: self.add_progress_bar_metrics(epoch_progress_bar_metrics) + def __auto_reduce_results_on_epoch_end(self, epoch_output): + epoch_log_metrics = {} + epoch_progress_bar_metrics = {} + for opt_outputs in epoch_output: + # reduce across time first + time_reduced_outputs = [] + for train_step_idx in range(len(opt_outputs)): + tbptt_outs = opt_outputs[train_step_idx] + tbptt_outs = tbptt_outs[0].__class__.reduce_across_time(tbptt_outs) + time_reduced_outputs.append(tbptt_outs) + + # reduce across training steps + opt_outputs = time_reduced_outputs[0].__class__.reduce_on_epoch_end(time_reduced_outputs) + opt_outputs.minimize = opt_outputs.minimize.mean() + epoch_log_metrics.update(opt_outputs.epoch_log_metrics) + epoch_progress_bar_metrics.update(opt_outputs.epoch_pbar_metrics) + + return epoch_log_metrics, epoch_progress_bar_metrics + + def __gather_result_across_time_and_optimizers(self, epoch_output): + """ + Gather results into a single padded tensor per metric where each tensor is gathered across + time and across time steps. 
+ + Returns: + a list where each element is a Result with the tensors gathered + """ + gathered_epoch_outputs = [] + for opt_outputs in epoch_output: + # gather across time first + time_gathered_outputs = [] + for train_step_idx in range(len(opt_outputs)): + tbptt_outs = opt_outputs[train_step_idx] + tbptt_outs = tbptt_outs[0].__class__.gather(tbptt_outs) + time_gathered_outputs.append(tbptt_outs) + + # gather across training steps + # each metric has dimensions (training_steps, seq_len) (seq_len=1 when no tbptt is used) + gathered_opt_output = time_gathered_outputs[0].__class__.padded_gather(time_gathered_outputs) + gathered_epoch_outputs.append(gathered_opt_output) + + return gathered_epoch_outputs + def sync_horovod(self): if self.use_horovod: hvd.join(hvd.local_rank() if self.on_gpu else -1) @@ -687,6 +778,9 @@ class TrainerTrainLoopMixin(ABC): using_results_obj = False + # track all outputs across time and num of optimizers + batch_outputs = [[] for i in range(len(self._get_optimizers_iterable()))] + if batch is None: return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic) @@ -739,7 +833,7 @@ class TrainerTrainLoopMixin(ABC): batch_idx, opt_idx, optimizer, - self.hiddens, + self.hiddens ) using_results_obj = isinstance(opt_closure_result.training_step_output, Result) @@ -767,6 +861,9 @@ class TrainerTrainLoopMixin(ABC): # track hiddens self.hiddens = opt_closure_result.hiddens + if using_results_obj: + opt_closure_result.training_step_output_for_epoch_end.drop_hiddens() + # check if loss or model weights are nan if self.terminate_on_nan: self.detect_nan_tensors(opt_closure_result.loss) @@ -774,6 +871,9 @@ class TrainerTrainLoopMixin(ABC): # track total loss for logging (avoid mem leaks) self.batch_loss_value.append(opt_closure_result.loss) + # track all the outputs across all steps + batch_outputs[opt_idx].append(opt_closure_result.training_step_output_for_epoch_end) + # ------------------------------ # BACKWARD PASS # ------------------------------ @@ -816,7 +916,7 @@ class TrainerTrainLoopMixin(ABC): signal=0, grad_norm_dic=grad_norm_dic, batch_log_metrics=batch_log_metrics, - training_step_output_for_epoch_end=opt_closure_result.training_step_output_for_epoch_end + training_step_output_for_epoch_end=batch_outputs ) return result diff --git a/tests/base/model_train_steps.py b/tests/base/model_train_steps.py index af9f662508..16d05680c9 100644 --- a/tests/base/model_train_steps.py +++ b/tests/base/model_train_steps.py @@ -77,6 +77,7 @@ class TrainingStepVariations(ABC): """ result.log('train_epoch_end_metric', 1, on_epoch=True) self.training_epoch_end_called = True + return result def eval_step_full_loop_result_obj_dp(self, batch, batch_idx, optimizer_idx=None): diff --git a/tests/models/test_cpu.py b/tests/models/test_cpu.py index 378d7f6a28..75846b7ecb 100644 --- a/tests/models/test_cpu.py +++ b/tests/models/test_cpu.py @@ -10,6 +10,7 @@ import tests.base.develop_utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.core.step_result import TrainResult from tests.base import EvalModelTemplate @@ -322,6 +323,160 @@ def test_tbptt_cpu_model(tmpdir): 'hiddens': self.test_hidden, } + def training_epoch_end(self, training_step_outputs): + training_step_outputs = training_step_outputs[0] + assert len(training_step_outputs) == (sequence_size / truncated_bptt_steps) + loss = torch.stack([x['loss'] for x in training_step_outputs]).mean() + return {'log': 
{'train_loss': loss}} + + def train_dataloader(self): + return torch.utils.data.DataLoader( + dataset=MockSeq2SeqDataset(), + batch_size=batch_size, + shuffle=False, + sampler=None, + ) + + hparams = EvalModelTemplate.get_default_hparams() + hparams.update( + batch_size=batch_size, + in_features=truncated_bptt_steps, + hidden_dim=truncated_bptt_steps, + out_features=truncated_bptt_steps + ) + + model = BpttTestModel(**hparams) + + # fit model + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + truncated_bptt_steps=truncated_bptt_steps, + limit_val_batches=0, + weights_summary=None, + early_stop_callback=False, + ) + result = trainer.fit(model) + + assert result == 1, 'training failed to complete' + + +def test_tbptt_cpu_model_result(tmpdir): + """Test truncated back propagation through time works.""" + truncated_bptt_steps = 2 + sequence_size = 30 + batch_size = 30 + + x_seq = torch.rand(batch_size, sequence_size, 1) + y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist() + + class MockSeq2SeqDataset(torch.utils.data.Dataset): + def __getitem__(self, i): + return x_seq, y_seq_list + + def __len__(self): + return 1 + + class BpttTestModel(EvalModelTemplate): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.test_hidden = None + + def training_step(self, batch, batch_idx, hiddens): + assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps" + self.test_hidden = torch.rand(1) + + x_tensor, y_list = batch + assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed" + + y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype) + assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed" + + pred = self(x_tensor.view(batch_size, truncated_bptt_steps)) + loss_val = torch.nn.functional.mse_loss( + pred, y_tensor.view(batch_size, truncated_bptt_steps)) + + result = TrainResult(loss_val, hiddens=self.test_hidden) + return result + + def training_epoch_end(self, training_step_outputs): + result = training_step_outputs + assert isinstance(result, TrainResult) + assert result.minimize.size(1) == (sequence_size / truncated_bptt_steps) + + result.minimize = result.minimize.mean() + return result + + def train_dataloader(self): + return torch.utils.data.DataLoader( + dataset=MockSeq2SeqDataset(), + batch_size=batch_size, + shuffle=False, + sampler=None, + ) + + hparams = EvalModelTemplate.get_default_hparams() + hparams.update( + batch_size=batch_size, + in_features=truncated_bptt_steps, + hidden_dim=truncated_bptt_steps, + out_features=truncated_bptt_steps + ) + + model = BpttTestModel(**hparams) + + # fit model + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + truncated_bptt_steps=truncated_bptt_steps, + limit_val_batches=0, + weights_summary=None, + early_stop_callback=False, + ) + result = trainer.fit(model) + + assert result == 1, 'training failed to complete' + + +def test_tbptt_cpu_model_result_auto_reduce(tmpdir): + """Test truncated back propagation through time works.""" + truncated_bptt_steps = 2 + sequence_size = 30 + batch_size = 30 + + x_seq = torch.rand(batch_size, sequence_size, 1) + y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist() + + class MockSeq2SeqDataset(torch.utils.data.Dataset): + def __getitem__(self, i): + return x_seq, y_seq_list + + def __len__(self): + return 1 + + class BpttTestModel(EvalModelTemplate): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.test_hidden = None + + def training_step(self, 
batch, batch_idx, hiddens): + assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps" + self.test_hidden = torch.rand(1) + + x_tensor, y_list = batch + assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed" + + y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype) + assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed" + + pred = self(x_tensor.view(batch_size, truncated_bptt_steps)) + loss_val = torch.nn.functional.mse_loss( + pred, y_tensor.view(batch_size, truncated_bptt_steps)) + + result = TrainResult(loss_val, hiddens=self.test_hidden) + return result + def train_dataloader(self): return torch.utils.data.DataLoader( dataset=MockSeq2SeqDataset(), diff --git a/tests/trainer/test_trainer_steps_dict_return.py b/tests/trainer/test_trainer_steps_dict_return.py index 7d6df7a207..91dd9cbc75 100644 --- a/tests/trainer/test_trainer_steps_dict_return.py +++ b/tests/trainer/test_trainer_steps_dict_return.py @@ -35,6 +35,9 @@ def test_training_step_dict(tmpdir): assert out.batch_log_metrics['log_acc2'] == 7.0 train_step_out = out.training_step_output_for_epoch_end + assert len(train_step_out) == 1 + + train_step_out = train_step_out[0][0] pbar_metrics = train_step_out['progress_bar'] assert 'log' in train_step_out assert 'progress_bar' in train_step_out @@ -118,7 +121,10 @@ def test_full_training_loop_dict(tmpdir): assert out.batch_log_metrics['log_acc1'] == 14.0 assert out.batch_log_metrics['log_acc2'] == 9.0 + # get the output of the first optimizer train_step_end_out = out.training_step_output_for_epoch_end + assert len(train_step_end_out) == 1 + train_step_end_out = train_step_end_out[0][0] pbar_metrics = train_step_end_out['progress_bar'] assert pbar_metrics['pbar_acc1'] == 19.0 assert pbar_metrics['pbar_acc2'] == 21.0 @@ -158,7 +164,11 @@ def test_train_step_epoch_end(tmpdir): assert out.batch_log_metrics['log_acc1'] == 12.0 assert out.batch_log_metrics['log_acc2'] == 7.0 + # outputs are for 1 optimizer and no tbptt train_step_end_out = out.training_step_output_for_epoch_end + assert len(train_step_end_out) == 1 + train_step_end_out = train_step_end_out[0][0] + pbar_metrics = train_step_end_out['progress_bar'] assert pbar_metrics['pbar_acc1'] == 17.0 assert pbar_metrics['pbar_acc2'] == 19.0 diff --git a/tests/trainer/test_trainer_steps_result_return.py b/tests/trainer/test_trainer_steps_result_return.py index 1785fea3c0..62b0b6e483 100644 --- a/tests/trainer/test_trainer_steps_result_return.py +++ b/tests/trainer/test_trainer_steps_result_return.py @@ -74,6 +74,8 @@ def test_training_step_result_log_step_only(tmpdir): assert out.batch_log_metrics[f'step_log_acc2_b{batch_idx}'] == 12.0 train_step_out = out.training_step_output_for_epoch_end + assert len(train_step_out) == 1 + train_step_out = train_step_out[0][0] assert isinstance(train_step_out, TrainResult) assert 'minimize' in train_step_out @@ -146,6 +148,8 @@ def test_training_step_result_log_epoch_only(tmpdir): assert len(out.batch_log_metrics) == 0 train_step_out = out.training_step_output_for_epoch_end + assert len(train_step_out) == 1 + train_step_out = train_step_out[0][0] assert isinstance(train_step_out, TrainResult) assert 'minimize' in train_step_out @@ -277,6 +281,8 @@ def test_training_step_result_log_step_and_epoch(tmpdir): assert len(out.batch_log_metrics) == 2 train_step_out = out.training_step_output_for_epoch_end + assert len(train_step_out) == 1 + train_step_out = train_step_out[0][0] assert isinstance(train_step_out, TrainResult) assert 
'minimize' in train_step_out @@ -354,6 +360,8 @@ def test_training_step_epoch_end_result(tmpdir): assert len(out.batch_log_metrics) == 2 train_step_out = out.training_step_output_for_epoch_end + assert len(train_step_out) == 1 + train_step_out = train_step_out[0][0] assert isinstance(train_step_out, TrainResult) assert 'minimize' in train_step_out diff --git a/tests/trainer/test_trainer_steps_scalar_return.py b/tests/trainer/test_trainer_steps_scalar_return.py index e5eb1e9bcc..65a92a49de 100644 --- a/tests/trainer/test_trainer_steps_scalar_return.py +++ b/tests/trainer/test_trainer_steps_scalar_return.py @@ -37,6 +37,8 @@ def test_training_step_scalar(tmpdir): assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict) train_step_out = out.training_step_output_for_epoch_end + assert len(train_step_out) == 1 + train_step_out = train_step_out[0][0] assert isinstance(train_step_out, torch.Tensor) assert train_step_out.item() == 171 @@ -72,6 +74,8 @@ def training_step_scalar_with_step_end(tmpdir): assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict) train_step_out = out.training_step_output_for_epoch_end + assert len(train_step_out) == 1 + train_step_out = train_step_out[0][0] assert isinstance(train_step_out, torch.Tensor) assert train_step_out.item() == 171 @@ -117,6 +121,8 @@ def test_full_training_loop_scalar(tmpdir): assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict) train_step_out = out.training_step_output_for_epoch_end + assert len(train_step_out) == 1 + train_step_out = train_step_out[0][0] assert isinstance(train_step_out, torch.Tensor) assert train_step_out.item() == 171 @@ -158,6 +164,8 @@ def test_train_step_epoch_end_scalar(tmpdir): assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict) train_step_out = out.training_step_output_for_epoch_end + assert len(train_step_out) == 1 + train_step_out = train_step_out[0][0] assert isinstance(train_step_out, torch.Tensor) assert train_step_out.item() == 171
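
The new tests above (tests/models/test_cpu.py) exercise the full behaviour this patch adds: training_step outputs are now tracked per optimizer and per TBPTT split, gathered as epoch_output[optimizer_idx][training_step_idx][tbptt_index], and either auto-reduced with the new tbptt_reduce_fx / tbptt_pad_token arguments or handed to training_epoch_end. Below is a condensed, hedged sketch of that usage pattern, modelled on BpttTestModel from the patch's tests; the module, dataset, and hyper-parameter names (TbpttSketchModel, RandomSeqDataset, the toy shapes and learning rate) are illustrative assumptions, not part of the patch:

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.core.step_result import TrainResult


class RandomSeqDataset(Dataset):
    """Toy (time, feature) sequences; the trainer splits the time dim for TBPTT."""

    def __init__(self, n_samples: int = 16, seq_len: int = 8):
        self.x = torch.rand(n_samples, seq_len, 1)
        self.y = torch.rand(n_samples, seq_len, 1)

    def __len__(self):
        return self.x.size(0)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


class TbpttSketchModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)
        self.hidden = None  # stand-in for a real recurrent state

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx, hiddens):
        # `hiddens` is whatever the previous TBPTT split returned via `hiddens=`
        x, y = batch  # each split has shape (batch, truncated_bptt_steps, 1)
        self.hidden = torch.rand(1)

        loss = torch.nn.functional.mse_loss(self(x), y)

        # the kwargs added by this patch: how this metric is reduced across
        # TBPTT splits, and which value pads splits of unequal length
        result = TrainResult(loss, hiddens=self.hidden)
        result.log('tbptt_loss', loss, on_step=False, on_epoch=True,
                   tbptt_reduce_fx=torch.mean, tbptt_pad_token=0)

        # with no training_epoch_end override the trainer auto-reduces:
        # first across TBPTT splits (tbptt_reduce_fx), then across batches.
        # Internally the outputs are tracked as
        #   epoch_output[optimizer_idx][training_step_idx][tbptt_index]
        return result

    def train_dataloader(self):
        return DataLoader(RandomSeqDataset(), batch_size=4)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.02)


if __name__ == '__main__':
    trainer = Trainer(
        max_epochs=1,
        truncated_bptt_steps=2,   # split each batch along dim 1 into chunks of 2
        limit_val_batches=0,
        weights_summary=None,
        early_stop_callback=False,
    )
    trainer.fit(TbpttSketchModel())

If training_epoch_end is overridden instead, the patch gathers the per-split results with padded_gather before the call, so (with a single optimizer) the module receives one TrainResult whose minimize tensor has shape (num_training_steps, num_tbptt_splits), as asserted in test_tbptt_cpu_model_result above.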