added tbptt test for logging (#3850)
* added tbptt test for logging
This commit is contained in:
parent
00f0d19a61
commit
2bca89a752
@@ -275,3 +275,80 @@ def test__training_step__log_max_reduce_fx(tmpdir, batches, fx, result):
    # make sure types are correct
    assert trainer.logged_metrics['foo'] == result
    assert trainer.logged_metrics['bar'] == result


def test_tbptt_log(tmpdir):
"""
|
||||
Tests that only training_step can be used
|
||||
"""
|
||||
    truncated_bptt_steps = 2
    sequence_size = 30
    batch_size = 30

    x_seq = torch.rand(batch_size, sequence_size, 1)
    y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist()

    class MockSeq2SeqDataset(torch.utils.data.Dataset):
        def __getitem__(self, i):
            return x_seq, y_seq_list

        def __len__(self):
            return 1

    class TestModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.test_hidden = None
            self.layer = torch.nn.Linear(2, 2)

        def training_step(self, batch, batch_idx, hiddens):
            try:
                assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
            except Exception as e:
                print(e)

            self.test_hidden = torch.rand(1)

            x_tensor, y_list = batch
            assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed"

            y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype)
            assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed"

            pred = self(x_tensor.view(batch_size, truncated_bptt_steps))
            loss_val = torch.nn.functional.mse_loss(
                pred, y_tensor.view(batch_size, truncated_bptt_steps))

            self.log('a', loss_val, on_epoch=True)

            return {'loss': loss_val, 'hiddens': self.test_hidden}

        def on_train_epoch_start(self) -> None:
            self.test_hidden = None

        def train_dataloader(self):
            return torch.utils.data.DataLoader(
                dataset=MockSeq2SeqDataset(),
                batch_size=batch_size,
                shuffle=False,
                sampler=None,
            )

    model = TestModel()
    model.training_epoch_end = None
    model.example_input_array = torch.randn(5, truncated_bptt_steps)

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=10,
        limit_val_batches=0,
        truncated_bptt_steps=truncated_bptt_steps,
        max_epochs=2,
        row_log_interval=2,
        weights_summary=None,
    )
    trainer.fit(model)

    generated = set(trainer.logged_metrics.keys())
    expected = {'a', 'step_a', 'epoch_a', 'epoch'}
    assert generated == expected
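For reference, a minimal, self-contained sketch of the time-dimension splitting that the shape assertions above rely on. This is illustrative only, not Lightning's actual tbptt_split_batch implementation, and the helper name split_time_dim is hypothetical:

import torch

def split_time_dim(x, tbptt_steps):
    # chunk along dim=1 (time); each chunk would feed one training_step call
    # together with the hidden state carried over from the previous chunk
    return list(torch.split(x, tbptt_steps, dim=1))

x_seq = torch.rand(30, 30, 1)         # (batch, sequence, feature), as in the test
splits = split_time_dim(x_seq, 2)     # truncated_bptt_steps = 2
assert len(splits) == 15              # 30 timesteps split into chunks of 2
assert splits[0].shape == (30, 2, 1)  # batch size preserved, time dim truncated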