added tbptt test for logging (#3850)

* added tbptt test for logging

* added tbptt test for logging
This commit is contained in:
William Falcon 2020-10-04 19:38:42 -04:00 committed by GitHub
parent 00f0d19a61
commit 2bca89a752
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 77 additions and 0 deletions

View File

@ -275,3 +275,80 @@ def test__training_step__log_max_reduce_fx(tmpdir, batches, fx, result):
# make sure types are correct
assert trainer.logged_metrics['foo'] == result
assert trainer.logged_metrics['bar'] == result
def test_tbptt_log(tmpdir):
"""
Tests that only training_step can be used
"""
truncated_bptt_steps = 2
sequence_size = 30
batch_size = 30
x_seq = torch.rand(batch_size, sequence_size, 1)
y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist()
class MockSeq2SeqDataset(torch.utils.data.Dataset):
def __getitem__(self, i):
return x_seq, y_seq_list
def __len__(self):
return 1
class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.test_hidden = None
self.layer = torch.nn.Linear(2, 2)
def training_step(self, batch, batch_idx, hiddens):
try:
assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
except Exception as e:
print(e)
self.test_hidden = torch.rand(1)
x_tensor, y_list = batch
assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed"
y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype)
assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed"
pred = self(x_tensor.view(batch_size, truncated_bptt_steps))
loss_val = torch.nn.functional.mse_loss(
pred, y_tensor.view(batch_size, truncated_bptt_steps))
self.log('a', loss_val, on_epoch=True)
return {'loss': loss_val, 'hiddens': self.test_hidden}
def on_train_epoch_start(self) -> None:
self.test_hidden = None
def train_dataloader(self):
return torch.utils.data.DataLoader(
dataset=MockSeq2SeqDataset(),
batch_size=batch_size,
shuffle=False,
sampler=None,
)
model = TestModel()
model.training_epoch_end = None
model.example_input_array = torch.randn(5, truncated_bptt_steps)
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=10,
limit_val_batches=0,
truncated_bptt_steps=truncated_bptt_steps,
max_epochs=2,
row_log_interval=2,
weights_summary=None,
)
trainer.fit(model)
generated = set(trainer.logged_metrics.keys())
expected = {'a', 'step_a', 'epoch_a', 'epoch'}
assert generated == expected