From 6a0d50369382ec3e7cce681facc9a2aad2a71d00 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Fri, 4 Jun 2021 17:34:39 +0200 Subject: [PATCH] Add warning to trainstep output (#7779) * Update training_loop.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_training_loop.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_training_loop.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update training_loop.py * Update pytorch_lightning/trainer/training_loop.py Co-authored-by: Ethan Harris * Update test_training_loop.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update training_loop.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update pytorch_lightning/trainer/training_loop.py Co-authored-by: ananthsub * Update training_loop.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_training_loop.py * Update training_loop.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * escape regex Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ethan Harris Co-authored-by: ananthsub --- pytorch_lightning/trainer/training_loop.py | 12 ++++++++++- tests/trainer/loops/test_training_loop.py | 24 ++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 156fea5d37..4b48cf0acb 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -16,7 +16,7 @@ from collections import OrderedDict from contextlib import contextmanager, suppress from copy import copy from functools import partial, update_wrapper -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union import numpy as np import torch @@ -269,6 +269,16 @@ class TrainLoop: if training_step_output.grad_fn is None: # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ... raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor") + elif self.trainer.lightning_module.automatic_optimization: + if not any(( + isinstance(training_step_output, torch.Tensor), + (isinstance(training_step_output, Mapping) + and 'loss' in training_step_output), training_step_output is None + )): + raise MisconfigurationException( + "In automatic optimization, `training_step` must either return a Tensor, " + "a dict with key 'loss' or None (where the step will be skipped)." + ) def training_step(self, split_batch, batch_idx, opt_idx, hiddens): # give the PL module a result for logging diff --git a/tests/trainer/loops/test_training_loop.py b/tests/trainer/loops/test_training_loop.py index da4ecbe5a9..a2706e5d37 100644 --- a/tests/trainer/loops/test_training_loop.py +++ b/tests/trainer/loops/test_training_loop.py @@ -11,10 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import re + import pytest import torch from pytorch_lightning import seed_everything, Trainer +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel @@ -142,3 +145,24 @@ def test_should_stop_mid_epoch(tmpdir): assert trainer.current_epoch == 0 assert trainer.global_step == 5 assert model.validation_called_at == (0, 4) + + +@pytest.mark.parametrize(['output'], [(5., ), ({'a': 5}, )]) +def test_warning_invalid_trainstep_output(tmpdir, output): + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + return output + + model = TestModel() + + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) + with pytest.raises( + MisconfigurationException, + match=re.escape( + "In automatic optimization, `training_step` must either return a Tensor, " + "a dict with key 'loss' or None (where the step will be skipped)." + ) + ): + trainer.fit(model)