Add warning to trainstep output (#7779)

* Update training_loop.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update test_training_loop.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update test_training_loop.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update training_loop.py

* Update pytorch_lightning/trainer/training_loop.py

Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>

* Update test_training_loop.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update training_loop.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update pytorch_lightning/trainer/training_loop.py

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>

* Update training_loop.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update test_training_loop.py

* Update training_loop.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* escape regex

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
This commit is contained in:
Justus Schock 2021-06-04 17:34:39 +02:00 committed by GitHub
parent ca89a7f344
commit 6a0d503693
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 1 deletions

View File

@ -16,7 +16,7 @@ from collections import OrderedDict
from contextlib import contextmanager, suppress
from copy import copy
from functools import partial, update_wrapper
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
import numpy as np
import torch
@ -269,6 +269,16 @@ class TrainLoop:
if training_step_output.grad_fn is None:
# TODO: Find why - RuntimeError: Expected to mark a variable ready only once ...
raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor")
elif self.trainer.lightning_module.automatic_optimization:
if not any((
isinstance(training_step_output, torch.Tensor),
(isinstance(training_step_output, Mapping)
and 'loss' in training_step_output), training_step_output is None
)):
raise MisconfigurationException(
"In automatic optimization, `training_step` must either return a Tensor, "
"a dict with key 'loss' or None (where the step will be skipped)."
)
def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
# give the PL module a result for logging

View File

@ -11,10 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import pytest
import torch
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
@ -142,3 +145,24 @@ def test_should_stop_mid_epoch(tmpdir):
assert trainer.current_epoch == 0
assert trainer.global_step == 5
assert model.validation_called_at == (0, 4)
@pytest.mark.parametrize(['output'], [(5., ), ({'a': 5}, )])
def test_warning_invalid_trainstep_output(tmpdir, output):
class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
return output
model = TestModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
with pytest.raises(
MisconfigurationException,
match=re.escape(
"In automatic optimization, `training_step` must either return a Tensor, "
"a dict with key 'loss' or None (where the step will be skipped)."
)
):
trainer.fit(model)