Add warning to trainstep output (#7779)
* Update training_loop.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_training_loop.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_training_loop.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update training_loop.py * Update pytorch_lightning/trainer/training_loop.py Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk> * Update test_training_loop.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update training_loop.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update pytorch_lightning/trainer/training_loop.py Co-authored-by: ananthsub <ananth.subramaniam@gmail.com> * Update training_loop.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_training_loop.py * Update training_loop.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * escape regex Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk> Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
This commit is contained in:
parent
ca89a7f344
commit
6a0d503693
|
@ -16,7 +16,7 @@ from collections import OrderedDict
|
|||
from contextlib import contextmanager, suppress
|
||||
from copy import copy
|
||||
from functools import partial, update_wrapper
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -269,6 +269,16 @@ class TrainLoop:
|
|||
if training_step_output.grad_fn is None:
|
||||
# TODO: Find why - RuntimeError: Expected to mark a variable ready only once ...
|
||||
raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor")
|
||||
elif self.trainer.lightning_module.automatic_optimization:
|
||||
if not any((
|
||||
isinstance(training_step_output, torch.Tensor),
|
||||
(isinstance(training_step_output, Mapping)
|
||||
and 'loss' in training_step_output), training_step_output is None
|
||||
)):
|
||||
raise MisconfigurationException(
|
||||
"In automatic optimization, `training_step` must either return a Tensor, "
|
||||
"a dict with key 'loss' or None (where the step will be skipped)."
|
||||
)
|
||||
|
||||
def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
|
||||
# give the PL module a result for logging
|
||||
|
|
|
@ -11,10 +11,13 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import re
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_lightning import seed_everything, Trainer
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests.helpers import BoringModel
|
||||
|
||||
|
||||
|
@ -142,3 +145,24 @@ def test_should_stop_mid_epoch(tmpdir):
|
|||
assert trainer.current_epoch == 0
|
||||
assert trainer.global_step == 5
|
||||
assert model.validation_called_at == (0, 4)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['output'], [(5., ), ({'a': 5}, )])
|
||||
def test_warning_invalid_trainstep_output(tmpdir, output):
|
||||
|
||||
class TestModel(BoringModel):
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
return output
|
||||
|
||||
model = TestModel()
|
||||
|
||||
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
|
||||
with pytest.raises(
|
||||
MisconfigurationException,
|
||||
match=re.escape(
|
||||
"In automatic optimization, `training_step` must either return a Tensor, "
|
||||
"a dict with key 'loss' or None (where the step will be skipped)."
|
||||
)
|
||||
):
|
||||
trainer.fit(model)
|
||||
|
|
Loading…
Reference in New Issue