added tbptt test for logging (#3850)
* added tbptt test for logging * added tbptt test for logging
This commit is contained in:
parent
00f0d19a61
commit
2bca89a752
|
@ -275,3 +275,80 @@ def test__training_step__log_max_reduce_fx(tmpdir, batches, fx, result):
|
||||||
# make sure types are correct
|
# make sure types are correct
|
||||||
assert trainer.logged_metrics['foo'] == result
|
assert trainer.logged_metrics['foo'] == result
|
||||||
assert trainer.logged_metrics['bar'] == result
|
assert trainer.logged_metrics['bar'] == result
|
||||||
|
|
||||||
|
|
||||||
|
def test_tbptt_log(tmpdir):
|
||||||
|
"""
|
||||||
|
Tests that only training_step can be used
|
||||||
|
"""
|
||||||
|
truncated_bptt_steps = 2
|
||||||
|
sequence_size = 30
|
||||||
|
batch_size = 30
|
||||||
|
|
||||||
|
x_seq = torch.rand(batch_size, sequence_size, 1)
|
||||||
|
y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist()
|
||||||
|
|
||||||
|
class MockSeq2SeqDataset(torch.utils.data.Dataset):
|
||||||
|
def __getitem__(self, i):
|
||||||
|
return x_seq, y_seq_list
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return 1
|
||||||
|
|
||||||
|
class TestModel(BoringModel):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.test_hidden = None
|
||||||
|
self.layer = torch.nn.Linear(2, 2)
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx, hiddens):
|
||||||
|
try:
|
||||||
|
assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
self.test_hidden = torch.rand(1)
|
||||||
|
|
||||||
|
x_tensor, y_list = batch
|
||||||
|
assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed"
|
||||||
|
|
||||||
|
y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype)
|
||||||
|
assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed"
|
||||||
|
|
||||||
|
pred = self(x_tensor.view(batch_size, truncated_bptt_steps))
|
||||||
|
loss_val = torch.nn.functional.mse_loss(
|
||||||
|
pred, y_tensor.view(batch_size, truncated_bptt_steps))
|
||||||
|
|
||||||
|
self.log('a', loss_val, on_epoch=True)
|
||||||
|
|
||||||
|
return {'loss': loss_val, 'hiddens': self.test_hidden}
|
||||||
|
|
||||||
|
def on_train_epoch_start(self) -> None:
|
||||||
|
self.test_hidden = None
|
||||||
|
|
||||||
|
def train_dataloader(self):
|
||||||
|
return torch.utils.data.DataLoader(
|
||||||
|
dataset=MockSeq2SeqDataset(),
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
sampler=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
model = TestModel()
|
||||||
|
model.training_epoch_end = None
|
||||||
|
model.example_input_array = torch.randn(5, truncated_bptt_steps)
|
||||||
|
|
||||||
|
trainer = Trainer(
|
||||||
|
default_root_dir=tmpdir,
|
||||||
|
limit_train_batches=10,
|
||||||
|
limit_val_batches=0,
|
||||||
|
truncated_bptt_steps=truncated_bptt_steps,
|
||||||
|
max_epochs=2,
|
||||||
|
row_log_interval=2,
|
||||||
|
weights_summary=None,
|
||||||
|
)
|
||||||
|
trainer.fit(model)
|
||||||
|
|
||||||
|
generated = set(trainer.logged_metrics.keys())
|
||||||
|
expected = {'a', 'step_a', 'epoch_a', 'epoch'}
|
||||||
|
assert generated == expected
|
||||||
|
|
Loading…
Reference in New Issue