diff --git a/tests/trainer/logging/test_train_loop_logging_1_0.py b/tests/trainer/logging/test_train_loop_logging_1_0.py index a852c4efe2..848e6e74a5 100644 --- a/tests/trainer/logging/test_train_loop_logging_1_0.py +++ b/tests/trainer/logging/test_train_loop_logging_1_0.py @@ -275,3 +275,80 @@ def test__training_step__log_max_reduce_fx(tmpdir, batches, fx, result): # make sure types are correct assert trainer.logged_metrics['foo'] == result assert trainer.logged_metrics['bar'] == result + + +def test_tbptt_log(tmpdir): + """ + Tests that only training_step can be used + """ + truncated_bptt_steps = 2 + sequence_size = 30 + batch_size = 30 + + x_seq = torch.rand(batch_size, sequence_size, 1) + y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist() + + class MockSeq2SeqDataset(torch.utils.data.Dataset): + def __getitem__(self, i): + return x_seq, y_seq_list + + def __len__(self): + return 1 + + class TestModel(BoringModel): + def __init__(self): + super().__init__() + self.test_hidden = None + self.layer = torch.nn.Linear(2, 2) + + def training_step(self, batch, batch_idx, hiddens): + try: + assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps" + except Exception as e: + print(e) + + self.test_hidden = torch.rand(1) + + x_tensor, y_list = batch + assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed" + + y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype) + assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed" + + pred = self(x_tensor.view(batch_size, truncated_bptt_steps)) + loss_val = torch.nn.functional.mse_loss( + pred, y_tensor.view(batch_size, truncated_bptt_steps)) + + self.log('a', loss_val, on_epoch=True) + + return {'loss': loss_val, 'hiddens': self.test_hidden} + + def on_train_epoch_start(self) -> None: + self.test_hidden = None + + def train_dataloader(self): + return torch.utils.data.DataLoader( + dataset=MockSeq2SeqDataset(), + batch_size=batch_size, + shuffle=False, + sampler=None, + ) + + model = TestModel() + model.training_epoch_end = None + model.example_input_array = torch.randn(5, truncated_bptt_steps) + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=10, + limit_val_batches=0, + truncated_bptt_steps=truncated_bptt_steps, + max_epochs=2, + row_log_interval=2, + weights_summary=None, + ) + trainer.fit(model) + + generated = set(trainer.logged_metrics.keys()) + expected = {'a', 'step_a', 'epoch_a', 'epoch'} + assert generated == expected