tracks all outputs including TBPTT and multiple optimizers (#2890)

* pl 0.9 update
William Falcon 2020-08-09 06:00:15 -04:00 committed by GitHub
parent 4d0406ec8b
commit 256059a1d0
7 changed files with 485 additions and 49 deletions
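
At a high level, the commit changes how training_step outputs are tracked during an epoch: instead of one flat list, the trainer keeps one list per optimizer, with one entry per training step, each holding the outputs of every TBPTT split. A minimal illustrative sketch of that nested shape (not part of the diff; the sizes below are made up):

import torch

num_optimizers = 2       # e.g. a GAN with separate generator/discriminator optimizers
num_training_steps = 3   # batches seen this epoch
tbptt_splits = 4         # time chunks per batch; 1 when TBPTT is not used

# epoch_output[optimizer_idx][training_step_idx][tbptt_index] -> one step output
epoch_output = [
    [[torch.tensor(1.0) for _ in range(tbptt_splits)] for _ in range(num_training_steps)]
    for _ in range(num_optimizers)
]

assert len(epoch_output) == num_optimizers
assert len(epoch_output[0]) == num_training_steps
assert len(epoch_output[0][0]) == tbptt_splits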


@ -92,6 +92,8 @@ class Result(Dict):
on_step: bool = False,
on_epoch: bool = True,
reduce_fx: Callable = torch.mean,
tbptt_reduce_fx: Callable = torch.mean,
tbptt_pad_token: int = 0,
enable_graph: bool = False,
sync_ddp: bool = False,
sync_ddp_op: Union[Any, str] = 'mean',
@ -113,15 +115,22 @@ class Result(Dict):
if on_step and on_epoch:
# set step version
step_name = f'step_{name}'
self.__set_meta(step_name, value, prog_bar, logger, on_step=True, on_epoch=False, reduce_fx=reduce_fx)
self.__set_meta(step_name, value, prog_bar, logger,
on_step=True, on_epoch=False,
reduce_fx=reduce_fx, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token)
self.__setitem__(step_name, value)
# set epoch version
epoch_name = f'epoch_{name}'
self.__set_meta(epoch_name, value, prog_bar, logger, on_step=False, on_epoch=True, reduce_fx=reduce_fx)
self.__set_meta(epoch_name, value, prog_bar, logger, on_step=False, on_epoch=True,
reduce_fx=reduce_fx, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token)
self.__setitem__(epoch_name, value)
else:
self.__set_meta(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx)
self.__set_meta(name, value,
prog_bar, logger,
on_step, on_epoch,
reduce_fx,
tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token)
# set the value
self.__setitem__(name, value)
@ -135,6 +144,8 @@ class Result(Dict):
on_step: bool,
on_epoch: bool,
reduce_fx: Callable,
tbptt_pad_token: int,
tbptt_reduce_fx: Callable
):
# set the meta for the item
meta_value = value
@ -144,7 +155,9 @@ class Result(Dict):
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx,
value=meta_value
value=meta_value,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token
)
self['meta'][name] = meta
@ -253,6 +266,39 @@ class Result(Dict):
result['meta'] = meta
return result
@classmethod
def padded_gather(cls, outputs):
meta = outputs[0].get('meta')
result = cls()
result = recursive_gather(outputs, result)
# find the padding used for other values
default_padding_idx = 0
for name, value in result.items():
if isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor):
if name not in {'checkpoint_on', 'early_stop_on', 'minimize'}:
default_padding_idx = meta[name]['tbptt_pad_token']
break
# pad across each key individually
for name, value in result.items():
is_reserved = name in {'checkpoint_on', 'early_stop_on', 'minimize'}
if isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor):
if is_reserved:
padding_key = default_padding_idx
else:
padding_key = meta[name]['tbptt_pad_token']
padded = torch.nn.utils.rnn.pad_sequence(value, batch_first=True, padding_value=padding_key)
result[name] = padded
# also update the result
if meta and not is_reserved:
meta[name]['value'] = padded
if meta:
result['meta'] = meta
return result
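
For context, here is a standalone sketch of the padding step padded_gather relies on: when training steps produce different numbers of TBPTT splits for a metric, the shorter sequences are right-padded with that metric's tbptt_pad_token using torch.nn.utils.rnn.pad_sequence. The tensors and pad value below are made-up examples, not taken from the diff:

import torch
from torch.nn.utils.rnn import pad_sequence

# two training steps produced 4 and 2 TBPTT values for the same metric
gathered_metric = [torch.tensor([0.1, 0.2, 0.3, 0.4]), torch.tensor([0.5, 0.6])]
pad_token = 0  # would come from meta[name]['tbptt_pad_token']

padded = pad_sequence(gathered_metric, batch_first=True, padding_value=pad_token)
print(padded.shape)  # torch.Size([2, 4]) -> (training_steps, max_num_splits)
print(padded[1])     # tensor([0.5000, 0.6000, 0.0000, 0.0000])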
@classmethod
def reduce_on_epoch_end(cls, outputs):
meta = outputs[0]['meta']
@ -271,10 +317,36 @@ class Result(Dict):
result['meta'] = meta
return result
@classmethod
def reduce_across_time(cls, time_outputs):
# auto-reduce across time for tbptt
meta = time_outputs[0]['meta']
result = cls()
result = recursive_gather(time_outputs, result)
recursive_stack(result)
for k, value in result.items():
if k == 'meta':
continue
# pick the reduce fx
if k in ['checkpoint_on', 'early_stop_on', 'minimize']:
tbptt_reduce_fx = torch.mean
else:
tbptt_reduce_fx = meta[k]['tbptt_reduce_fx']
result[k] = tbptt_reduce_fx(value)
result['meta'] = meta
return result
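
A standalone sketch of what reduce_across_time does for one metric: the values collected over a batch's TBPTT splits are collapsed with the metric's tbptt_reduce_fx, while the reserved keys (minimize, checkpoint_on, early_stop_on) always fall back to torch.mean. The values are illustrative:

import torch

# values of one logged metric across the TBPTT splits of a batch
time_values = torch.tensor([0.9, 0.7, 0.5])

tbptt_reduce_fx = torch.max              # e.g. the metric was logged with tbptt_reduce_fx=torch.max
print(tbptt_reduce_fx(time_values))      # tensor(0.9000)

# reserved keys always use torch.mean regardless of the configured reduce fx
print(torch.mean(time_values))           # tensor(0.7000)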
@property
def should_reduce_on_epoch_end(self) -> bool:
return self['meta']['_internal']['_reduce_on_epoch']
def drop_hiddens(self):
if 'hiddens' in self:
del self['hiddens']
def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] = None) -> Optional[MutableMapping]:
for out in outputs:
@ -303,6 +375,16 @@ def recursive_stack(result: MutableMapping):
result[k] = v
def recursive_padded_stack(result: MutableMapping):
for k, v in result.items():
if isinstance(v, dict):
recursive_stack(v)
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], torch.Tensor):
v = torch.stack(v)
result[k] = v
class TrainResult(Result):
def __init__(
@ -348,6 +430,8 @@ class TrainResult(Result):
on_step: bool = True,
on_epoch: bool = False,
reduce_fx: Callable = torch.mean,
tbptt_reduce_fx: Callable = torch.mean,
tbptt_pad_token: int = 0,
enable_graph: bool = False,
sync_ddp: bool = False,
sync_ddp_op: Union[Any, str] = 'mean',
@ -381,10 +465,26 @@ class TrainResult(Result):
on_step: if True logs the output of validation_step or test_step
on_epoch: if True, logs the output of the training loop aggregated
reduce_fx: Torch.mean by default
tbptt_reduce_fx: function to reduce on truncated back prop
tbptt_pad_token: token to use for padding
enable_graph: if True, will not auto detach the graph
sync_ddp: if True, reduces the metric across GPUs/TPUs
sync_ddp_op: the op to sync across
sync_ddp_group: the ddp group
"""
super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)
super().log(name=name,
value=value,
prog_bar=prog_bar,
logger=logger,
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx,
enable_graph=enable_graph,
sync_ddp=sync_ddp,
sync_ddp_group=sync_ddp_group,
sync_ddp_op=sync_ddp_op,
tbptt_pad_token=tbptt_pad_token,
tbptt_reduce_fx=tbptt_reduce_fx)
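
A usage sketch for the extended log signature above, based only on the parameters shown in this diff; the module, loss computation, and rnn_forward helper are hypothetical:

import torch
from pytorch_lightning.core.step_result import TrainResult

def training_step(self, batch, batch_idx, hiddens):
    # rnn_forward is a hypothetical helper returning a loss and the new hidden state
    loss, hiddens = self.rnn_forward(batch, hiddens)
    result = TrainResult(loss, hiddens=hiddens)
    result.log('train_loss', loss,
               tbptt_reduce_fx=torch.mean,  # how the splits of one batch are reduced over time
               tbptt_pad_token=0)           # pad value used when splits are gathered across batches
    return result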
def log_dict(
self,
@ -394,6 +494,8 @@ class TrainResult(Result):
on_step: bool = False,
on_epoch: bool = True,
reduce_fx: Callable = torch.mean,
tbptt_reduce_fx: Callable = torch.mean,
tbptt_pad_token: int = 0,
enable_graph: bool = False,
sync_ddp: bool = False,
sync_ddp_op: Union[Any, str] = 'mean',
@ -408,17 +510,33 @@ class TrainResult(Result):
result.log_dict(values)
Args:
dictionary:
prog_bar:
logger:
on_step:
on_epoch:
reduce_fx:
enable_graph:
dictionary: key value pairs (str, tensors)
prog_bar: if True logs to the progress bar
logger: if True logs to the logger
on_step: if True logs the output of validation_step or test_step
on_epoch: if True, logs the output of the training loop aggregated
reduce_fx: Torch.mean by default
tbptt_reduce_fx: function to reduce on truncated back prop
tbptt_pad_token: token to use for padding
enable_graph: if True, will not auto detach the graph
sync_ddp: if True, reduces the metric across GPUs/TPUs
sync_ddp_op: the op to sync across
sync_ddp_group: the ddp group:
"""
for k, v in dictionary.items():
self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)
self.log(name=k,
value=v,
prog_bar=prog_bar,
logger=logger,
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx,
enable_graph=enable_graph,
sync_ddp=sync_ddp,
sync_ddp_group=sync_ddp_group,
sync_ddp_op=sync_ddp_op,
tbptt_pad_token=tbptt_pad_token,
tbptt_reduce_fx=tbptt_reduce_fx)
class EvalResult(Result):
@ -464,6 +582,8 @@ class EvalResult(Result):
on_step: bool = False,
on_epoch: bool = True,
reduce_fx: Callable = torch.mean,
tbptt_reduce_fx: Callable = torch.mean,
tbptt_pad_token: int = 0,
enable_graph: bool = False,
sync_ddp: bool = False,
sync_ddp_op: Union[Any, str] = 'mean',
@ -494,12 +614,28 @@ class EvalResult(Result):
prog_bar: if True logs to the progress bar
logger: if True logs to the logger
on_step: if True logs the output of validation_step or test_step
on_epoch: if True, logs the output of the validation loop or test loop aggregated
on_epoch: if True, logs the output of the training loop aggregated
reduce_fx: Torch.mean by default
enable_graph: if True, will not auto detach the graph :
tbptt_reduce_fx: function to reduce on truncated back prop
tbptt_pad_token: token to use for padding
enable_graph: if True, will not auto detach the graph
sync_ddp: if True, reduces the metric across GPUs/TPUs
sync_ddp_op: the op to sync across
sync_ddp_group: the ddp group
"""
super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)
super().log(name=name,
value=value,
prog_bar=prog_bar,
logger=logger,
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx,
enable_graph=enable_graph,
sync_ddp=sync_ddp,
sync_ddp_group=sync_ddp_group,
sync_ddp_op=sync_ddp_op,
tbptt_pad_token=tbptt_pad_token,
tbptt_reduce_fx=tbptt_reduce_fx)
def log_dict(
self,
@ -509,6 +645,8 @@ class EvalResult(Result):
on_step: bool = False,
on_epoch: bool = True,
reduce_fx: Callable = torch.mean,
tbptt_reduce_fx: Callable = torch.mean,
tbptt_pad_token: int = 0,
enable_graph: bool = False,
sync_ddp: bool = False,
sync_ddp_op: Union[Any, str] = 'mean',
@ -523,17 +661,33 @@ class EvalResult(Result):
result.log_dict(values)
Args:
dictionary:
prog_bar:
logger:
on_step:
on_epoch:
reduce_fx:
enable_graph:
dictionary: key value pairs (str, tensors)
prog_bar: if True logs to the progress bar
logger: if True logs to the logger
on_step: if True logs the output of validation_step or test_step
on_epoch: if True, logs the output of the training loop aggregated
reduce_fx: Torch.mean by default
tbptt_reduce_fx: function to reduce on truncated back prop
tbptt_pad_token: token to use for padding
enable_graph: if True, will not auto detach the graph
sync_ddp: if True, reduces the metric across GPUs/TPUs
sync_ddp_op: the op to sync across
sync_ddp_group: the ddp group
"""
for k, v in dictionary.items():
self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)
self.log(name=k,
value=v,
prog_bar=prog_bar,
logger=logger,
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx,
enable_graph=enable_graph,
sync_ddp=sync_ddp,
sync_ddp_group=sync_ddp_group,
sync_ddp_op=sync_ddp_op,
tbptt_pad_token=tbptt_pad_token,
tbptt_reduce_fx=tbptt_reduce_fx)
def get_callback_metrics(self) -> dict:
result = {


@ -462,7 +462,8 @@ class TrainerTrainLoopMixin(ABC):
train_dataloader = self.prepare_train_loop_dataloader(self.train_dataloader)
# bookkeeping
epoch_output = []
num_optimizers = len(self._get_optimizers_iterable())
epoch_output = [[] for _ in range(num_optimizers)]
should_check_val = False
# structured result accumulators for callbacks
@ -487,16 +488,18 @@ class TrainerTrainLoopMixin(ABC):
# only track outputs when user implements training_epoch_end
# otherwise we will build up unnecessary memory
step_out = batch_output.training_step_output_for_epoch_end
should_auto_reduce_train_result = isinstance(step_out, Result) and step_out.should_reduce_on_epoch_end
if isinstance(step_out, dict) and 'early_stop_on' in step_out:
early_stopping_accumulator.accumulate(step_out['early_stop_on'])
epoch_end_outputs = self.process_train_step_outputs(
batch_output.training_step_output_for_epoch_end,
early_stopping_accumulator,
checkpoint_accumulator
)
if isinstance(step_out, dict) and 'checkpoint_on' in step_out:
checkpoint_accumulator.accumulate(step_out['checkpoint_on'])
if self.is_overridden('training_epoch_end', model=self.get_model()) or should_auto_reduce_train_result:
epoch_output.append(batch_output.training_step_output_for_epoch_end)
# track the outputs to reduce at the end of the epoch
for opt_idx, opt_outputs in enumerate(epoch_end_outputs):
# with 1 step (no tbptt) don't use a sequence at epoch end
if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result):
opt_outputs = opt_outputs[0]
epoch_output[opt_idx].append(opt_outputs)
# update LR schedulers
self.update_train_loop_lr_schedulers()
@ -538,7 +541,7 @@ class TrainerTrainLoopMixin(ABC):
self.sync_horovod()
# process epoch outputs
self.run_training_epoch_end(epoch_output, checkpoint_accumulator, early_stopping_accumulator)
self.run_training_epoch_end(epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers)
# checkpoint callback
self.check_checkpoint_callback(should_check_val)
@ -546,6 +549,35 @@ class TrainerTrainLoopMixin(ABC):
# epoch end hook
self.run_on_epoch_end_hook(model)
def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accumulator, checkpoint_accumulator):
"""
Figure out what needs to be tracked/logged at the end of the epoch
"""
# the training step outputs a list per optimizer. The list contains the outputs at each time step
# when no TBPTT is used, then the list has 1 item per batch
# when TBPTT IS used, then the list has n items (1 per time step)
epoch_end_outputs = []
for optimizer_idx_outputs in all_train_step_outputs:
# extract one representative sample from each time step (1 if no tbptt) and 0th optimizer
sample_output = optimizer_idx_outputs[-1]
# pull out callback info if available (ie: Results object)
if isinstance(sample_output, dict) and 'early_stop_on' in sample_output:
early_stopping_accumulator.accumulate(sample_output['early_stop_on'])
if isinstance(sample_output, dict) and 'checkpoint_on' in sample_output:
checkpoint_accumulator.accumulate(sample_output['checkpoint_on'])
# decide if we need to reduce at the end of the epoch automatically
auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end
# only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
if self.is_overridden('training_epoch_end', model=self.get_model()) or auto_reduce_tng_result:
epoch_end_outputs.append(optimizer_idx_outputs)
return epoch_end_outputs
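
The filter above reduces to one decision per optimizer, sketched here with hypothetical names: outputs are kept for epoch end only when the user overrides training_epoch_end or the Result object asks to be auto-reduced.

def should_track_for_epoch_end(sample_output, overrides_training_epoch_end: bool) -> bool:
    # stand-in for the check performed in process_train_step_outputs: track when the
    # user wants epoch-end outputs or the Result wants automatic epoch-end reduction
    wants_auto_reduce = getattr(sample_output, 'should_reduce_on_epoch_end', False)
    return overrides_training_epoch_end or wants_auto_reduce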
def check_checkpoint_callback(self, should_check_val):
# when no val loop is present or fast-dev-run still need to call checkpoints
# TODO bake this logic into the checkpoint callback
@ -575,9 +607,12 @@ class TrainerTrainLoopMixin(ABC):
if self.is_function_implemented('on_train_epoch_end'):
model.on_train_epoch_end()
def run_training_epoch_end(self, epoch_output, checkpoint_accumulator, early_stopping_accumulator):
def run_training_epoch_end(self, epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers):
# epoch output is a list. Each item in that list has all the outputs per optimizer
# epoch_output[optimizer_idx][training_step_idx][tbptt_index]
# remember that not using truncated backprop is equivalent with truncated back prop of len(1)
model = self.get_model()
is_result_obj = len(epoch_output) > 0 and isinstance(epoch_output[0], Result)
epoch_log_metrics = {}
epoch_callback_metrics = {}
@ -592,17 +627,33 @@ class TrainerTrainLoopMixin(ABC):
if early_stopping_accumulator.num_values > 0:
epoch_callback_metrics['early_stop_on'] = early_stopping_accumulator.mean()
# ------------------------
# determine if using a result obj
# ------------------------
# [optimizer_idx][training_step_idx][tbptt_index]
opt_idx_outputs = epoch_output[0]
try:
sample_obj = opt_idx_outputs[0][0] if isinstance(opt_idx_outputs[0], list) else opt_idx_outputs[0]
is_result_obj = len(epoch_output) > 0 and isinstance(sample_obj, Result)
except IndexError as e:
is_result_obj = False
# --------------------------
# EPOCH END STEP IF DEFINED
# --------------------------
if self.is_overridden('training_epoch_end', model=model):
self.global_step += 1
# remove the protected keys so the user doesn't have to deal with them
if is_result_obj:
epoch_output = epoch_output[0].__class__.gather(epoch_output)
# with result object gather across time and training steps so each opt idx has a single result obj
epoch_output = self.__gather_result_across_time_and_optimizers(epoch_output)
if num_optimizers == 1:
epoch_output = epoch_output[0]
# run training_epoch_end
# a list with a result per optimizer index
epoch_output = model.training_epoch_end(epoch_output)
if isinstance(epoch_output, Result):
@ -618,10 +669,7 @@ class TrainerTrainLoopMixin(ABC):
# Structured Result (auto epoch end)
# --------------------------
elif is_result_obj:
epoch_output = epoch_output[0].__class__.reduce_on_epoch_end(epoch_output)
epoch_output.minimize = epoch_output.minimize.mean()
epoch_log_metrics = epoch_output.epoch_log_metrics
epoch_progress_bar_metrics = epoch_output.epoch_pbar_metrics
epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output)
# --------------------------
# track results
@ -637,6 +685,49 @@ class TrainerTrainLoopMixin(ABC):
if len(epoch_progress_bar_metrics) > 0:
self.add_progress_bar_metrics(epoch_progress_bar_metrics)
def __auto_reduce_results_on_epoch_end(self, epoch_output):
epoch_log_metrics = {}
epoch_progress_bar_metrics = {}
for opt_outputs in epoch_output:
# reduce across time first
time_reduced_outputs = []
for train_step_idx in range(len(opt_outputs)):
tbptt_outs = opt_outputs[train_step_idx]
tbptt_outs = tbptt_outs[0].__class__.reduce_across_time(tbptt_outs)
time_reduced_outputs.append(tbptt_outs)
# reduce across training steps
opt_outputs = time_reduced_outputs[0].__class__.reduce_on_epoch_end(time_reduced_outputs)
opt_outputs.minimize = opt_outputs.minimize.mean()
epoch_log_metrics.update(opt_outputs.epoch_log_metrics)
epoch_progress_bar_metrics.update(opt_outputs.epoch_pbar_metrics)
return epoch_log_metrics, epoch_progress_bar_metrics
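
The two-stage auto-reduction above, restated as a standalone tensor example (functions and values are illustrative, not from the diff): each batch's TBPTT splits are reduced first with tbptt_reduce_fx, then the per-batch results are reduced across the epoch:

import torch

# per_batch_tbptt[training_step_idx] holds one metric's values for that batch's TBPTT splits
per_batch_tbptt = [torch.tensor([1.0, 2.0, 3.0]),   # batch 0
                   torch.tensor([4.0, 5.0, 6.0])]   # batch 1

tbptt_reduce_fx = torch.mean   # stage 1: reduce across the splits of each batch
reduce_fx = torch.mean         # stage 2: reduce across batches for the epoch

time_reduced = torch.stack([tbptt_reduce_fx(t) for t in per_batch_tbptt])  # tensor([2., 5.])
print(reduce_fx(time_reduced))                                             # tensor(3.5000)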
def __gather_result_across_time_and_optimizers(self, epoch_output):
"""
Gather results into a single padded tensor per metric where each tensor is gathered across
time and across training steps.
Returns:
a list where each element is a Result with the tensors gathered
"""
gathered_epoch_outputs = []
for opt_outputs in epoch_output:
# gather across time first
time_gathered_outputs = []
for train_step_idx in range(len(opt_outputs)):
tbptt_outs = opt_outputs[train_step_idx]
tbptt_outs = tbptt_outs[0].__class__.gather(tbptt_outs)
time_gathered_outputs.append(tbptt_outs)
# gather across training steps
# each metric has dimensions (training_steps, seq_len) (seq_len=1 when no tbptt is used)
gathered_opt_output = time_gathered_outputs[0].__class__.padded_gather(time_gathered_outputs)
gathered_epoch_outputs.append(gathered_opt_output)
return gathered_epoch_outputs
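
The gathering counterpart in plain tensors (again illustrative, not from the diff): a metric is first stacked across the TBPTT splits of each batch, then padded and stacked across training steps, which yields the (training_steps, seq_len) shape mentioned in the comment above:

import torch
from torch.nn.utils.rnn import pad_sequence

# one metric, collected per batch with one scalar per TBPTT split
batch0_splits = [torch.tensor(0.1), torch.tensor(0.2), torch.tensor(0.3)]
batch1_splits = [torch.tensor(0.4), torch.tensor(0.5)]

# gather across time within each batch
time_gathered = [torch.stack(batch0_splits), torch.stack(batch1_splits)]

# gather across training steps, padding the shorter sequence with the pad token
gathered = pad_sequence(time_gathered, batch_first=True, padding_value=0)
print(gathered.shape)  # torch.Size([2, 3]) -> (training_steps, seq_len)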
def sync_horovod(self):
if self.use_horovod:
hvd.join(hvd.local_rank() if self.on_gpu else -1)
@ -687,6 +778,9 @@ class TrainerTrainLoopMixin(ABC):
using_results_obj = False
# track all outputs across time and num of optimizers
batch_outputs = [[] for i in range(len(self._get_optimizers_iterable()))]
if batch is None:
return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic)
@ -739,7 +833,7 @@ class TrainerTrainLoopMixin(ABC):
batch_idx,
opt_idx,
optimizer,
self.hiddens,
self.hiddens
)
using_results_obj = isinstance(opt_closure_result.training_step_output, Result)
@ -767,6 +861,9 @@ class TrainerTrainLoopMixin(ABC):
# track hiddens
self.hiddens = opt_closure_result.hiddens
if using_results_obj:
opt_closure_result.training_step_output_for_epoch_end.drop_hiddens()
# check if loss or model weights are nan
if self.terminate_on_nan:
self.detect_nan_tensors(opt_closure_result.loss)
@ -774,6 +871,9 @@ class TrainerTrainLoopMixin(ABC):
# track total loss for logging (avoid mem leaks)
self.batch_loss_value.append(opt_closure_result.loss)
# track all the outputs across all steps
batch_outputs[opt_idx].append(opt_closure_result.training_step_output_for_epoch_end)
# ------------------------------
# BACKWARD PASS
# ------------------------------
@ -816,7 +916,7 @@ class TrainerTrainLoopMixin(ABC):
signal=0,
grad_norm_dic=grad_norm_dic,
batch_log_metrics=batch_log_metrics,
training_step_output_for_epoch_end=opt_closure_result.training_step_output_for_epoch_end
training_step_output_for_epoch_end=batch_outputs
)
return result


@ -77,6 +77,7 @@ class TrainingStepVariations(ABC):
"""
result.log('train_epoch_end_metric', 1, on_epoch=True)
self.training_epoch_end_called = True
return result
def eval_step_full_loop_result_obj_dp(self, batch, batch_idx, optimizer_idx=None):


@ -10,6 +10,7 @@ import tests.base.develop_utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.step_result import TrainResult
from tests.base import EvalModelTemplate
@ -322,6 +323,160 @@ def test_tbptt_cpu_model(tmpdir):
'hiddens': self.test_hidden,
}
def training_epoch_end(self, training_step_outputs):
training_step_outputs = training_step_outputs[0]
assert len(training_step_outputs) == (sequence_size / truncated_bptt_steps)
loss = torch.stack([x['loss'] for x in training_step_outputs]).mean()
return {'log': {'train_loss': loss}}
def train_dataloader(self):
return torch.utils.data.DataLoader(
dataset=MockSeq2SeqDataset(),
batch_size=batch_size,
shuffle=False,
sampler=None,
)
hparams = EvalModelTemplate.get_default_hparams()
hparams.update(
batch_size=batch_size,
in_features=truncated_bptt_steps,
hidden_dim=truncated_bptt_steps,
out_features=truncated_bptt_steps
)
model = BpttTestModel(**hparams)
# fit model
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
truncated_bptt_steps=truncated_bptt_steps,
limit_val_batches=0,
weights_summary=None,
early_stop_callback=False,
)
result = trainer.fit(model)
assert result == 1, 'training failed to complete'
def test_tbptt_cpu_model_result(tmpdir):
"""Test truncated back propagation through time works."""
truncated_bptt_steps = 2
sequence_size = 30
batch_size = 30
x_seq = torch.rand(batch_size, sequence_size, 1)
y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist()
class MockSeq2SeqDataset(torch.utils.data.Dataset):
def __getitem__(self, i):
return x_seq, y_seq_list
def __len__(self):
return 1
class BpttTestModel(EvalModelTemplate):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.test_hidden = None
def training_step(self, batch, batch_idx, hiddens):
assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
self.test_hidden = torch.rand(1)
x_tensor, y_list = batch
assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed"
y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype)
assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed"
pred = self(x_tensor.view(batch_size, truncated_bptt_steps))
loss_val = torch.nn.functional.mse_loss(
pred, y_tensor.view(batch_size, truncated_bptt_steps))
result = TrainResult(loss_val, hiddens=self.test_hidden)
return result
def training_epoch_end(self, training_step_outputs):
result = training_step_outputs
assert isinstance(result, TrainResult)
assert result.minimize.size(1) == (sequence_size / truncated_bptt_steps)
result.minimize = result.minimize.mean()
return result
def train_dataloader(self):
return torch.utils.data.DataLoader(
dataset=MockSeq2SeqDataset(),
batch_size=batch_size,
shuffle=False,
sampler=None,
)
hparams = EvalModelTemplate.get_default_hparams()
hparams.update(
batch_size=batch_size,
in_features=truncated_bptt_steps,
hidden_dim=truncated_bptt_steps,
out_features=truncated_bptt_steps
)
model = BpttTestModel(**hparams)
# fit model
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
truncated_bptt_steps=truncated_bptt_steps,
limit_val_batches=0,
weights_summary=None,
early_stop_callback=False,
)
result = trainer.fit(model)
assert result == 1, 'training failed to complete'
def test_tbptt_cpu_model_result_auto_reduce(tmpdir):
"""Test truncated back propagation through time works."""
truncated_bptt_steps = 2
sequence_size = 30
batch_size = 30
x_seq = torch.rand(batch_size, sequence_size, 1)
y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist()
class MockSeq2SeqDataset(torch.utils.data.Dataset):
def __getitem__(self, i):
return x_seq, y_seq_list
def __len__(self):
return 1
class BpttTestModel(EvalModelTemplate):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.test_hidden = None
def training_step(self, batch, batch_idx, hiddens):
assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
self.test_hidden = torch.rand(1)
x_tensor, y_list = batch
assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed"
y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype)
assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed"
pred = self(x_tensor.view(batch_size, truncated_bptt_steps))
loss_val = torch.nn.functional.mse_loss(
pred, y_tensor.view(batch_size, truncated_bptt_steps))
result = TrainResult(loss_val, hiddens=self.test_hidden)
return result
def train_dataloader(self):
return torch.utils.data.DataLoader(
dataset=MockSeq2SeqDataset(),


@ -35,6 +35,9 @@ def test_training_step_dict(tmpdir):
assert out.batch_log_metrics['log_acc2'] == 7.0
train_step_out = out.training_step_output_for_epoch_end
assert len(train_step_out) == 1
train_step_out = train_step_out[0][0]
pbar_metrics = train_step_out['progress_bar']
assert 'log' in train_step_out
assert 'progress_bar' in train_step_out
@ -118,7 +121,10 @@ def test_full_training_loop_dict(tmpdir):
assert out.batch_log_metrics['log_acc1'] == 14.0
assert out.batch_log_metrics['log_acc2'] == 9.0
# get the output of the first optimizer
train_step_end_out = out.training_step_output_for_epoch_end
assert len(train_step_end_out) == 1
train_step_end_out = train_step_end_out[0][0]
pbar_metrics = train_step_end_out['progress_bar']
assert pbar_metrics['pbar_acc1'] == 19.0
assert pbar_metrics['pbar_acc2'] == 21.0
@ -158,7 +164,11 @@ def test_train_step_epoch_end(tmpdir):
assert out.batch_log_metrics['log_acc1'] == 12.0
assert out.batch_log_metrics['log_acc2'] == 7.0
# outputs are for 1 optimizer and no tbptt
train_step_end_out = out.training_step_output_for_epoch_end
assert len(train_step_end_out) == 1
train_step_end_out = train_step_end_out[0][0]
pbar_metrics = train_step_end_out['progress_bar']
assert pbar_metrics['pbar_acc1'] == 17.0
assert pbar_metrics['pbar_acc2'] == 19.0


@ -74,6 +74,8 @@ def test_training_step_result_log_step_only(tmpdir):
assert out.batch_log_metrics[f'step_log_acc2_b{batch_idx}'] == 12.0
train_step_out = out.training_step_output_for_epoch_end
assert len(train_step_out) == 1
train_step_out = train_step_out[0][0]
assert isinstance(train_step_out, TrainResult)
assert 'minimize' in train_step_out
@ -146,6 +148,8 @@ def test_training_step_result_log_epoch_only(tmpdir):
assert len(out.batch_log_metrics) == 0
train_step_out = out.training_step_output_for_epoch_end
assert len(train_step_out) == 1
train_step_out = train_step_out[0][0]
assert isinstance(train_step_out, TrainResult)
assert 'minimize' in train_step_out
@ -277,6 +281,8 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
assert len(out.batch_log_metrics) == 2
train_step_out = out.training_step_output_for_epoch_end
assert len(train_step_out) == 1
train_step_out = train_step_out[0][0]
assert isinstance(train_step_out, TrainResult)
assert 'minimize' in train_step_out
@ -354,6 +360,8 @@ def test_training_step_epoch_end_result(tmpdir):
assert len(out.batch_log_metrics) == 2
train_step_out = out.training_step_output_for_epoch_end
assert len(train_step_out) == 1
train_step_out = train_step_out[0][0]
assert isinstance(train_step_out, TrainResult)
assert 'minimize' in train_step_out


@ -37,6 +37,8 @@ def test_training_step_scalar(tmpdir):
assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
train_step_out = out.training_step_output_for_epoch_end
assert len(train_step_out) == 1
train_step_out = train_step_out[0][0]
assert isinstance(train_step_out, torch.Tensor)
assert train_step_out.item() == 171
@ -72,6 +74,8 @@ def training_step_scalar_with_step_end(tmpdir):
assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
train_step_out = out.training_step_output_for_epoch_end
assert len(train_step_out) == 1
train_step_out = train_step_out[0][0]
assert isinstance(train_step_out, torch.Tensor)
assert train_step_out.item() == 171
@ -117,6 +121,8 @@ def test_full_training_loop_scalar(tmpdir):
assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
train_step_out = out.training_step_output_for_epoch_end
assert len(train_step_out) == 1
train_step_out = train_step_out[0][0]
assert isinstance(train_step_out, torch.Tensor)
assert train_step_out.item() == 171
@ -158,6 +164,8 @@ def test_train_step_epoch_end_scalar(tmpdir):
assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
train_step_out = out.training_step_output_for_epoch_end
assert len(train_step_out) == 1
train_step_out = train_step_out[0][0]
assert isinstance(train_step_out, torch.Tensor)
assert train_step_out.item() == 171