From 1d565e175d98103c2ebd6164e681f76143501da9 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sat, 11 Jul 2020 17:43:00 -0400
Subject: [PATCH] add tests for single scalar return from training (#2587)

* add tests for single scalar return from training

* add tests for single scalar return from training

* add tests for single scalar return from training

* add tests for single scalar return from training

* add tests for single scalar return from training
---
 pytorch_lightning/trainer/logging.py           |  11 ++
 pytorch_lightning/trainer/training_loop.py     |   5 +-
 tests/base/deterministic_model.py              |  41 +++++
 ...s.py => test_trainer_steps_dict_return.py}  |   6 +-
 .../test_trainer_steps_scalar_return.py        | 165 ++++++++++++++++++
 5 files changed, 224 insertions(+), 4 deletions(-)
 rename tests/trainer/{test_trainer_steps.py => test_trainer_steps_dict_return.py} (97%)
 create mode 100644 tests/trainer/test_trainer_steps_scalar_return.py

diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py
index 5349849e09..35f5d5d35b 100644
--- a/pytorch_lightning/trainer/logging.py
+++ b/pytorch_lightning/trainer/logging.py
@@ -98,6 +98,17 @@ class TrainerLoggingMixin(ABC):
 
         Separates loss from logging and progress bar metrics
         """
+        # --------------------------
+        # handle single scalar only
+        # --------------------------
+        # single scalar returned from a xx_step
+        if isinstance(output, torch.Tensor):
+            progress_bar_metrics = {}
+            log_metrics = {}
+            callback_metrics = {}
+            hiddens = None
+            return output, progress_bar_metrics, log_metrics, callback_metrics, hiddens
+
         # ---------------
         # EXTRACT CALLBACK KEYS
         # ---------------
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 008faa20ee..fa493f2e1b 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -792,7 +792,10 @@ class TrainerTrainLoopMixin(ABC):
             )
 
             # if the user decides to finally reduce things in epoch_end, save raw output without graphs
-            training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)
+            if isinstance(training_step_output_for_epoch_end, torch.Tensor):
+                training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
+            else:
+                training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)
 
             # accumulate loss
             # (if accumulate_grad_batches = 1 no effect)
diff --git a/tests/base/deterministic_model.py b/tests/base/deterministic_model.py
index 529d64f799..c387997da5 100644
--- a/tests/base/deterministic_model.py
+++ b/tests/base/deterministic_model.py
@@ -52,6 +52,47 @@ class DeterministicModel(LightningModule):
 
         return num_graphs
 
+    # ---------------------------
+    # scalar return
+    # ---------------------------
+    def training_step_scalar_return(self, batch, batch_idx):
+        acc = self.step(batch, batch_idx)
+        self.training_step_called = True
+        return acc
+
+    def training_step_end_scalar(self, output):
+        self.training_step_end_called = True
+
+        # make sure loss has the grad
+        assert isinstance(output, torch.Tensor)
+        assert output.grad_fn is not None
+
+        # make sure nothing else has grads
+        assert self.count_num_graphs({'loss': output}) == 1
+
+        assert output == 171
+
+        return output
+
+    def training_epoch_end_scalar(self, outputs):
+        """
+        There should be an array of scalars without graphs that are all 171 (4 of them)
+        """
+        self.training_epoch_end_called = True
+
+        if self.use_dp or self.use_ddp2:
+            pass
+        else:
+            # only saw 4 batches
+            assert len(outputs) == 4
+            for batch_out in outputs:
+                assert batch_out == 171
+                assert batch_out.grad_fn is None
+                assert isinstance(batch_out, torch.Tensor)
+
+        prototype_loss = outputs[0]
+        return prototype_loss
+
     # --------------------------
     # dictionary returns
     # --------------------------
diff --git a/tests/trainer/test_trainer_steps.py b/tests/trainer/test_trainer_steps_dict_return.py
similarity index 97%
rename from tests/trainer/test_trainer_steps.py
rename to tests/trainer/test_trainer_steps_dict_return.py
index 6091f48625..290983fbf6 100644
--- a/tests/trainer/test_trainer_steps.py
+++ b/tests/trainer/test_trainer_steps_dict_return.py
@@ -1,10 +1,10 @@
+"""
+Tests to ensure that the training loop works with a dict
+"""
 from pytorch_lightning import Trainer
 from tests.base.deterministic_model import DeterministicModel
-import pytest
-import torch
 
 
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_training_step_dict(tmpdir):
     """
     Tests that only training_step can be used
diff --git a/tests/trainer/test_trainer_steps_scalar_return.py b/tests/trainer/test_trainer_steps_scalar_return.py
new file mode 100644
index 0000000000..b893b58310
--- /dev/null
+++ b/tests/trainer/test_trainer_steps_scalar_return.py
@@ -0,0 +1,165 @@
+"""
+Tests to ensure that the training loop works with a scalar
+"""
+from pytorch_lightning import Trainer
+from tests.base.deterministic_model import DeterministicModel
+import torch
+
+
+def test_training_step_scalar(tmpdir):
+    """
+    Tests that only training_step that returns a single scalar can be used
+    """
+    model = DeterministicModel()
+    model.training_step = model.training_step_scalar_return
+    model.val_dataloader = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    # make sure correct steps were called
+    assert model.training_step_called
+    assert not model.training_step_end_called
+    assert not model.training_epoch_end_called
+
+    # make sure training outputs what is expected
+    for batch_idx, batch in enumerate(model.train_dataloader()):
+        break
+
+    out = trainer.run_training_batch(batch, batch_idx)
+    assert out.signal == 0
+    assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
+    assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
+
+    train_step_out = out.training_step_output_for_epoch_end
+    assert isinstance(train_step_out, torch.Tensor)
+    assert train_step_out.item() == 171
+
+    # make sure the optimizer closure returns the correct things
+    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+    assert opt_closure_result['loss'].item() == 171
+
+
+def training_step_scalar_with_step_end(tmpdir):
+    """
+    Checks train_step with scalar only + training_step_end
+    """
+    model = DeterministicModel()
+    model.training_step = model.training_step_scalar_return
+    model.training_step_end = model.training_step_end_scalar
+    model.val_dataloader = None
+
+    trainer = Trainer(fast_dev_run=True, weights_summary=None)
+    trainer.fit(model)
+
+    # make sure correct steps were called
+    assert model.training_step_called
+    assert model.training_step_end_called
+    assert not model.training_epoch_end_called
+
+    # make sure training outputs what is expected
+    for batch_idx, batch in enumerate(model.train_dataloader()):
+        break
+
+    out = trainer.run_training_batch(batch, batch_idx)
+    assert out.signal == 0
+    assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
+    assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
+
+    train_step_out = out.training_step_output_for_epoch_end
+    assert isinstance(train_step_out, torch.Tensor)
+    assert train_step_out.item() == 171
+
+    # make sure the optimizer closure returns the correct things
+    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+    assert opt_closure_result['loss'].item() == 171
+
+
+def test_full_training_loop_scalar(tmpdir):
+    """
+    Checks train_step + training_step_end + training_epoch_end
+    (all with scalar return from train_step)
+    """
+    model = DeterministicModel()
+    model.training_step = model.training_step_scalar_return
+    model.training_step_end = model.training_step_end_scalar
+    model.training_epoch_end = model.training_epoch_end_scalar
+    model.val_dataloader = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    # make sure correct steps were called
+    assert model.training_step_called
+    assert model.training_step_end_called
+    assert model.training_epoch_end_called
+
+    # assert epoch end metrics were added
+    assert 'epoch' in trainer.callback_metrics and len(trainer.callback_metrics) == 1
+    assert len(trainer.progress_bar_metrics) == 0
+
+    # make sure training outputs what is expected
+    for batch_idx, batch in enumerate(model.train_dataloader()):
+        break
+
+    out = trainer.run_training_batch(batch, batch_idx)
+    assert out.signal == 0
+    assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
+    assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
+
+    train_step_out = out.training_step_output_for_epoch_end
+    assert isinstance(train_step_out, torch.Tensor)
+    assert train_step_out.item() == 171
+
+    # make sure the optimizer closure returns the correct things
+    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+    assert opt_closure_result['loss'].item() == 171
+
+
+def test_train_step_epoch_end_scalar(tmpdir):
+    """
+    Checks train_step + training_epoch_end (NO training_step_end)
+    (with scalar return)
+    """
+    model = DeterministicModel()
+    model.training_step = model.training_step_scalar_return
+    model.training_step_end = None
+    model.training_epoch_end = model.training_epoch_end_scalar
+    model.val_dataloader = None
+
+    trainer = Trainer(max_epochs=1, weights_summary=None)
+    trainer.fit(model)
+
+    # make sure correct steps were called
+    assert model.training_step_called
+    assert not model.training_step_end_called
+    assert model.training_epoch_end_called
+
+    # assert epoch end metrics were added
+    assert 'epoch' in trainer.callback_metrics and len(trainer.callback_metrics) == 1
+    assert len(trainer.progress_bar_metrics) == 0
+
+    # make sure training outputs what is expected
+    for batch_idx, batch in enumerate(model.train_dataloader()):
+        break
+
+    out = trainer.run_training_batch(batch, batch_idx)
+    assert out.signal == 0
+    assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
+    assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
+
+    train_step_out = out.training_step_output_for_epoch_end
+    assert isinstance(train_step_out, torch.Tensor)
+    assert train_step_out.item() == 171
+
+    # make sure the optimizer closure returns the correct things
+    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+    assert opt_closure_result['loss'].item() == 171
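
What the patch enables, from the user side: training_step can now return a bare scalar loss tensor instead of a dict, and both process_output and the epoch-end detach path accept it. Below is a minimal usage sketch under that assumption, written against the 0.8.x-era Trainer API used elsewhere in this patch; the module name, layer sizes, and random data are illustrative and not part of the patch or the Lightning test suite.

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class ScalarReturnModel(LightningModule):
    """Illustrative module whose training_step returns a single scalar tensor."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # return the loss directly -- no {'loss': ...} dict wrapping needed after this patch
        return torch.nn.functional.mse_loss(self(x), y)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32), torch.randn(64, 1)), batch_size=16)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == '__main__':
    # fast_dev_run executes a single training batch, exercising the new scalar branch
    Trainer(fast_dev_run=True, weights_summary=None).fit(ScalarReturnModel())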