add tests for single scalar return from training (#2587)
* add tests for single scalar return from training
parent a34609ef0e
commit 1d565e175d
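For context, the behavior under test is a training_step that returns a bare loss tensor instead of a dict with a 'loss' key. A minimal sketch of such a module (hypothetical model, layer sizes, and data, not part of this commit):

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule


class ScalarReturnModel(LightningModule):
    """Toy module whose training_step returns a single scalar tensor."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # return the bare loss tensor, no dict and no logging keys
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        dataset = TensorDataset(torch.randn(64, 32), torch.randn(64, 1))
        return DataLoader(dataset, batch_size=16)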
@@ -98,6 +98,17 @@ class TrainerLoggingMixin(ABC):
         Separates loss from logging and progress bar metrics
         """
+        # --------------------------
+        # handle single scalar only
+        # --------------------------
+        # single scalar returned from a xx_step
+        if isinstance(output, torch.Tensor):
+            progress_bar_metrics = {}
+            log_metrics = {}
+            callback_metrics = {}
+            hiddens = None
+            return output, progress_bar_metrics, log_metrics, callback_metrics, hiddens
+
         # ---------------
         # EXTRACT CALLBACK KEYS
         # ---------------
 
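In effect, the new branch normalizes a bare tensor into the same five-element result that dict returns already produce, with empty metric dicts and no hiddens. A standalone sketch of that normalization (the helper name is made up for illustration, not the Trainer's API):

import torch


def normalize_step_output(output):
    # mirror the new branch: a bare tensor becomes (loss, {}, {}, {}, None)
    if isinstance(output, torch.Tensor):
        return output, {}, {}, {}, None
    raise NotImplementedError('dict handling omitted from this sketch')


loss = torch.tensor(171.0, requires_grad=True)
out, progress_bar_metrics, log_metrics, callback_metrics, hiddens = normalize_step_output(loss)
assert out is loss and progress_bar_metrics == {} and hiddens is None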
@@ -792,7 +792,10 @@ class TrainerTrainLoopMixin(ABC):
                 )
 
                 # if the user decides to finally reduce things in epoch_end, save raw output without graphs
-                training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)
+                if isinstance(training_step_output_for_epoch_end, torch.Tensor):
+                    training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
+                else:
+                    training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)
 
             # accumulate loss
             # (if accumulate_grad_batches = 1 no effect)
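The tensor branch is needed because a scalar output has no keys to walk; it can be detached directly, while dict outputs still go through recursive_detach. A rough illustration of the two paths (the recursive_detach here is a simplified stand-in, not Lightning's implementation):

import torch


def recursive_detach(value):
    # simplified stand-in: detach tensors found inside (nested) dicts
    if isinstance(value, dict):
        return {k: recursive_detach(v) for k, v in value.items()}
    return value.detach() if isinstance(value, torch.Tensor) else value


loss = torch.tensor(3.0, requires_grad=True) * 57                 # scalar attached to a graph
assert loss.grad_fn is not None
assert loss.detach().grad_fn is None                              # tensor branch
assert recursive_detach({'loss': loss})['loss'].grad_fn is None   # dict branch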
@@ -52,6 +52,47 @@ class DeterministicModel(LightningModule):
         return num_graphs
 
+    # ---------------------------
+    # scalar return
+    # ---------------------------
+    def training_step_scalar_return(self, batch, batch_idx):
+        acc = self.step(batch, batch_idx)
+        self.training_step_called = True
+        return acc
+
+    def training_step_end_scalar(self, output):
+        self.training_step_end_called = True
+
+        # make sure loss has the grad
+        assert isinstance(output, torch.Tensor)
+        assert output.grad_fn is not None
+
+        # make sure nothing else has grads
+        assert self.count_num_graphs({'loss': output}) == 1
+
+        assert output == 171
+
+        return output
+
+    def training_epoch_end_scalar(self, outputs):
+        """
+        There should be an array of scalars without graphs that are all 171 (4 of them)
+        """
+        self.training_epoch_end_called = True
+
+        if self.use_dp or self.use_ddp2:
+            pass
+        else:
+            # only saw 4 batches
+            assert len(outputs) == 4
+            for batch_out in outputs:
+                assert batch_out == 171
+                assert batch_out.grad_fn is None
+                assert isinstance(batch_out, torch.Tensor)
+
+        prototype_loss = outputs[0]
+        return prototype_loss
+
     # --------------------------
     # dictionary returns
     # --------------------------
 
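count_num_graphs is an existing helper on DeterministicModel; the assertion above relies on it reporting how many values in a result dict are still attached to an autograd graph. A hedged sketch of such a helper (an assumption for illustration, not the repository's code):

import torch


def count_num_graphs(result, num_graphs=0):
    # assumed behavior: count dict values that are tensors with a grad_fn
    for value in result.values():
        if isinstance(value, torch.Tensor) and value.grad_fn is not None:
            num_graphs += 1
        elif isinstance(value, dict):
            num_graphs += count_num_graphs(value)
    return num_graphs


loss = torch.nn.Linear(2, 1)(torch.randn(4, 2)).sum()
assert count_num_graphs({'loss': loss, 'lr': 0.1}) == 1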
@@ -1,10 +1,10 @@
 """
 Tests to ensure that the training loop works with a dict
 """
 from pytorch_lightning import Trainer
 from tests.base.deterministic_model import DeterministicModel
 import pytest
 import torch
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_training_step_dict(tmpdir):
     """
     Tests that only training_step can be used
@@ -0,0 +1,165 @@
+"""
+Tests to ensure that the training loop works with a scalar
+"""
+from pytorch_lightning import Trainer
+from tests.base.deterministic_model import DeterministicModel
+import torch
+
+
+def test_training_step_scalar(tmpdir):
+    """
+    Tests that only training_step that returns a single scalar can be used
+    """
+    model = DeterministicModel()
+    model.training_step = model.training_step_scalar_return
+    model.val_dataloader = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    # make sure correct steps were called
+    assert model.training_step_called
+    assert not model.training_step_end_called
+    assert not model.training_epoch_end_called
+
+    # make sure training outputs what is expected
+    for batch_idx, batch in enumerate(model.train_dataloader()):
+        break
+
+    out = trainer.run_training_batch(batch, batch_idx)
+    assert out.signal == 0
+    assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
+    assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
+
+    train_step_out = out.training_step_output_for_epoch_end
+    assert isinstance(train_step_out, torch.Tensor)
+    assert train_step_out.item() == 171
+
+    # make sure the optimizer closure returns the correct things
+    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+    assert opt_closure_result['loss'].item() == 171
+
+
+def training_step_scalar_with_step_end(tmpdir):
+    """
+    Checks train_step with scalar only + training_step_end
+    """
+    model = DeterministicModel()
+    model.training_step = model.training_step_scalar_return
+    model.training_step_end = model.training_step_end_scalar
+    model.val_dataloader = None
+
+    trainer = Trainer(fast_dev_run=True, weights_summary=None)
+    trainer.fit(model)
+
+    # make sure correct steps were called
+    assert model.training_step_called
+    assert model.training_step_end_called
+    assert not model.training_epoch_end_called
+
+    # make sure training outputs what is expected
+    for batch_idx, batch in enumerate(model.train_dataloader()):
+        break
+
+    out = trainer.run_training_batch(batch, batch_idx)
+    assert out.signal == 0
+    assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
+    assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
+
+    train_step_out = out.training_step_output_for_epoch_end
+    assert isinstance(train_step_out, torch.Tensor)
+    assert train_step_out.item() == 171
+
+    # make sure the optimizer closure returns the correct things
+    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+    assert opt_closure_result['loss'].item() == 171
+
+
+def test_full_training_loop_scalar(tmpdir):
+    """
+    Checks train_step + training_step_end + training_epoch_end
+    (all with scalar return from train_step)
+    """
+    model = DeterministicModel()
+    model.training_step = model.training_step_scalar_return
+    model.training_step_end = model.training_step_end_scalar
+    model.training_epoch_end = model.training_epoch_end_scalar
+    model.val_dataloader = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    # make sure correct steps were called
+    assert model.training_step_called
+    assert model.training_step_end_called
+    assert model.training_epoch_end_called
+
+    # assert epoch end metrics were added
+    assert 'epoch' in trainer.callback_metrics and len(trainer.callback_metrics) == 1
+    assert len(trainer.progress_bar_metrics) == 0
+
+    # make sure training outputs what is expected
+    for batch_idx, batch in enumerate(model.train_dataloader()):
+        break
+
+    out = trainer.run_training_batch(batch, batch_idx)
+    assert out.signal == 0
+    assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
+    assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
+
+    train_step_out = out.training_step_output_for_epoch_end
+    assert isinstance(train_step_out, torch.Tensor)
+    assert train_step_out.item() == 171
+
+    # make sure the optimizer closure returns the correct things
+    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+    assert opt_closure_result['loss'].item() == 171
+
+
+def test_train_step_epoch_end_scalar(tmpdir):
+    """
+    Checks train_step + training_epoch_end (NO training_step_end)
+    (with scalar return)
+    """
+    model = DeterministicModel()
+    model.training_step = model.training_step_scalar_return
+    model.training_step_end = None
+    model.training_epoch_end = model.training_epoch_end_scalar
+    model.val_dataloader = None
+
+    trainer = Trainer(max_epochs=1, weights_summary=None)
+    trainer.fit(model)
+
+    # make sure correct steps were called
+    assert model.training_step_called
+    assert not model.training_step_end_called
+    assert model.training_epoch_end_called
+
+    # assert epoch end metrics were added
+    assert 'epoch' in trainer.callback_metrics and len(trainer.callback_metrics) == 1
+    assert len(trainer.progress_bar_metrics) == 0
+
+    # make sure training outputs what is expected
+    for batch_idx, batch in enumerate(model.train_dataloader()):
+        break
+
+    out = trainer.run_training_batch(batch, batch_idx)
+    assert out.signal == 0
+    assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
+    assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
+
+    train_step_out = out.training_step_output_for_epoch_end
+    assert isinstance(train_step_out, torch.Tensor)
+    assert train_step_out.item() == 171
+
+    # make sure the optimizer closure returns the correct things
+    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+    assert opt_closure_result['loss'].item() == 171