add tests for single scalar return from training (#2587)

William Falcon 2020-07-11 17:43:00 -04:00 committed by GitHub
parent a34609ef0e
commit 1d565e175d
5 changed files with 224 additions and 4 deletions

@@ -98,6 +98,17 @@ class TrainerLoggingMixin(ABC):
        Separates loss from logging and progress bar metrics
        """

        # --------------------------
        # handle single scalar only
        # --------------------------
        # single scalar returned from a xx_step
        if isinstance(output, torch.Tensor):
            progress_bar_metrics = {}
            log_metrics = {}
            callback_metrics = {}
            hiddens = None
            return output, progress_bar_metrics, log_metrics, callback_metrics, hiddens

        # ---------------
        # EXTRACT CALLBACK KEYS
        # ---------------
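
For orientation, the new branch simply short-circuits metric separation when a bare tensor comes back from a step. A minimal, self-contained sketch of that behaviour (a stand-in function, not the actual TrainerLoggingMixin.process_output):

import torch

def process_scalar_output(output):
    # stand-in for the early-return branch above: a bare tensor yields
    # empty progress-bar/log/callback metric dicts and no hiddens
    if isinstance(output, torch.Tensor):
        return output, {}, {}, {}, None
    raise TypeError("non-tensor outputs take the usual dict path")

loss = torch.tensor(171.0, requires_grad=True)
out, progress_bar, logs, callbacks, hiddens = process_scalar_output(loss)
assert out is loss and progress_bar == {} and logs == {} and callbacks == {} and hiddens is None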

@@ -792,7 +792,10 @@ class TrainerTrainLoopMixin(ABC):
            )

            # if the user decides to finally reduce things in epoch_end, save raw output without graphs
-           training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)
+           if isinstance(training_step_output_for_epoch_end, torch.Tensor):
+               training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
+           else:
+               training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)

            # accumulate loss
            # (if accumulate_grad_batches = 1 no effect)
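
The point of this change: a bare tensor has no structure to walk, so it can be detached directly, while dict-style outputs still go through a recursive detach. A rough, self-contained sketch of the distinction (recursive_detach_sketch is a hypothetical stand-in for the library's recursive_detach helper):

import torch

def recursive_detach_sketch(output):
    # hypothetical stand-in: detach every tensor value found inside a dict output
    if isinstance(output, torch.Tensor):
        return output.detach()
    if isinstance(output, dict):
        return {k: recursive_detach_sketch(v) for k, v in output.items()}
    return output

loss = (torch.ones(1, requires_grad=True) * 171).sum()

# scalar return: detach the tensor itself
scalar_saved = loss.detach() if isinstance(loss, torch.Tensor) else recursive_detach_sketch(loss)
assert scalar_saved.grad_fn is None

# dict return: detach each tensor inside
dict_saved = recursive_detach_sketch({'loss': loss, 'log': {'train_loss': loss}})
assert dict_saved['loss'].grad_fn is None and dict_saved['log']['train_loss'].grad_fn is None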

@@ -52,6 +52,47 @@ class DeterministicModel(LightningModule):

        return num_graphs

    # ---------------------------
    # scalar return
    # ---------------------------
    def training_step_scalar_return(self, batch, batch_idx):
        acc = self.step(batch, batch_idx)
        self.training_step_called = True
        return acc

    def training_step_end_scalar(self, output):
        self.training_step_end_called = True

        # make sure loss has the grad
        assert isinstance(output, torch.Tensor)
        assert output.grad_fn is not None

        # make sure nothing else has grads
        assert self.count_num_graphs({'loss': output}) == 1

        assert output == 171
        return output

    def training_epoch_end_scalar(self, outputs):
        """
        There should be an array of scalars without graphs that are all 171 (4 of them)
        """
        self.training_epoch_end_called = True

        if self.use_dp or self.use_ddp2:
            pass
        else:
            # only saw 4 batches
            assert len(outputs) == 4
            for batch_out in outputs:
                assert batch_out == 171
                assert batch_out.grad_fn is None
                assert isinstance(batch_out, torch.Tensor)

        prototype_loss = outputs[0]
        return prototype_loss

    # --------------------------
    # dictionary returns
    # --------------------------
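
The assertion above relies on count_num_graphs, which is defined earlier in this file; the diff does not show its body, but conceptually it counts how many tensors in a result still carry a grad_fn. A small sketch of that idea under that assumption (not the file's exact implementation):

import torch

def count_num_graphs_sketch(result):
    # assumed behaviour: count tensors in the result dict that still hold an autograd graph
    return sum(1 for v in result.values()
               if isinstance(v, torch.Tensor) and v.grad_fn is not None)

loss = (torch.ones(1, requires_grad=True) * 171).sum()
result = {'loss': loss, 'batch_idx': torch.tensor(0)}
assert count_num_graphs_sketch(result) == 1  # only the loss keeps a graph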

@@ -1,10 +1,10 @@
"""
Tests to ensure that the training loop works with a dict
"""
from pytorch_lightning import Trainer
from tests.base.deterministic_model import DeterministicModel
import pytest
import torch


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_training_step_dict(tmpdir):
    """
    Tests that only training_step can be used

@@ -0,0 +1,165 @@
"""
Tests to ensure that the training loop works with a scalar
"""
from pytorch_lightning import Trainer
from tests.base.deterministic_model import DeterministicModel
import torch


def test_training_step_scalar(tmpdir):
    """
    Tests that only training_step that returns a single scalar can be used
    """
    model = DeterministicModel()
    model.training_step = model.training_step_scalar_return
    model.val_dataloader = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        weights_summary=None,
    )
    trainer.fit(model)

    # make sure correct steps were called
    assert model.training_step_called
    assert not model.training_step_end_called
    assert not model.training_epoch_end_called

    # make sure training outputs what is expected
    for batch_idx, batch in enumerate(model.train_dataloader()):
        break

    out = trainer.run_training_batch(batch, batch_idx)
    assert out.signal == 0
    assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
    assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)

    train_step_out = out.training_step_output_for_epoch_end
    assert isinstance(train_step_out, torch.Tensor)
    assert train_step_out.item() == 171

    # make sure the optimizer closure returns the correct things
    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
    assert opt_closure_result['loss'].item() == 171


def test_training_step_scalar_with_step_end(tmpdir):
    """
    Checks train_step with scalar only + training_step_end
    """
    model = DeterministicModel()
    model.training_step = model.training_step_scalar_return
    model.training_step_end = model.training_step_end_scalar
    model.val_dataloader = None

    trainer = Trainer(fast_dev_run=True, weights_summary=None)
    trainer.fit(model)

    # make sure correct steps were called
    assert model.training_step_called
    assert model.training_step_end_called
    assert not model.training_epoch_end_called

    # make sure training outputs what is expected
    for batch_idx, batch in enumerate(model.train_dataloader()):
        break

    out = trainer.run_training_batch(batch, batch_idx)
    assert out.signal == 0
    assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
    assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)

    train_step_out = out.training_step_output_for_epoch_end
    assert isinstance(train_step_out, torch.Tensor)
    assert train_step_out.item() == 171

    # make sure the optimizer closure returns the correct things
    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
    assert opt_closure_result['loss'].item() == 171


def test_full_training_loop_scalar(tmpdir):
    """
    Checks train_step + training_step_end + training_epoch_end
    (all with scalar return from train_step)
    """
    model = DeterministicModel()
    model.training_step = model.training_step_scalar_return
    model.training_step_end = model.training_step_end_scalar
    model.training_epoch_end = model.training_epoch_end_scalar
    model.val_dataloader = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        weights_summary=None,
    )
    trainer.fit(model)

    # make sure correct steps were called
    assert model.training_step_called
    assert model.training_step_end_called
    assert model.training_epoch_end_called

    # assert epoch end metrics were added
    assert 'epoch' in trainer.callback_metrics and len(trainer.callback_metrics) == 1
    assert len(trainer.progress_bar_metrics) == 0

    # make sure training outputs what is expected
    for batch_idx, batch in enumerate(model.train_dataloader()):
        break

    out = trainer.run_training_batch(batch, batch_idx)
    assert out.signal == 0
    assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
    assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)

    train_step_out = out.training_step_output_for_epoch_end
    assert isinstance(train_step_out, torch.Tensor)
    assert train_step_out.item() == 171

    # make sure the optimizer closure returns the correct things
    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
    assert opt_closure_result['loss'].item() == 171


def test_train_step_epoch_end_scalar(tmpdir):
    """
    Checks train_step + training_epoch_end (NO training_step_end)
    (with scalar return)
    """
    model = DeterministicModel()
    model.training_step = model.training_step_scalar_return
    model.training_step_end = None
    model.training_epoch_end = model.training_epoch_end_scalar
    model.val_dataloader = None

    trainer = Trainer(max_epochs=1, weights_summary=None)
    trainer.fit(model)

    # make sure correct steps were called
    assert model.training_step_called
    assert not model.training_step_end_called
    assert model.training_epoch_end_called

    # assert epoch end metrics were added
    assert 'epoch' in trainer.callback_metrics and len(trainer.callback_metrics) == 1
    assert len(trainer.progress_bar_metrics) == 0

    # make sure training outputs what is expected
    for batch_idx, batch in enumerate(model.train_dataloader()):
        break

    out = trainer.run_training_batch(batch, batch_idx)
    assert out.signal == 0
    assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
    assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)

    train_step_out = out.training_step_output_for_epoch_end
    assert isinstance(train_step_out, torch.Tensor)
    assert train_step_out.item() == 171

    # make sure the optimizer closure returns the correct things
    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
    assert opt_closure_result['loss'].item() == 171
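
For context, the user-facing pattern exercised by these tests is a training_step that returns the loss tensor directly instead of a dict. A minimal, hypothetical module of that shape (names, layer sizes, and data are illustrative, not taken from DeterministicModel):

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer

class ScalarReturnModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # return the bare scalar loss -- no {'loss': ...} dict needed
        return torch.nn.functional.mse_loss(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(8, 4), torch.randn(8, 1)), batch_size=4)

if __name__ == '__main__':
    Trainer(fast_dev_run=True, weights_summary=None).fit(ScalarReturnModel())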