Move NaN/Inf detection to a separate utilities file (#6834)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

commit 851f9e3997 (parent 90e37ba458)

@@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Added

+- Added utils for NaN/Inf detection for gradients and parameters ([#6834](https://github.com/PyTorchLightning/pytorch-lightning/pull/6834/))
+
 - Added more explicit exception message when trying to execute `trainer.test()` or `trainer.validate()` with `fast_dev_run=True` ([#6667](https://github.com/PyTorchLightning/pytorch-lightning/pull/6667))

@@ -113,6 +115,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Deprecated

+- Deprecated `TrainerTrainingTricksMixin` in favor of a separate utilities module for NaN/Inf detection for gradients and parameters ([#6834](https://github.com/PyTorchLightning/pytorch-lightning/pull/6834/))
+
 - `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
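The new utilities live in `pytorch_lightning.utilities.finite_checks` (added as a new module further down in this diff). A minimal usage sketch, assuming an ordinary `torch.nn.Module` whose gradients have already been populated by a backward pass:

    import torch
    import torch.nn as nn

    from pytorch_lightning.utilities.finite_checks import detect_nan_parameters, print_nan_gradients

    model = nn.Linear(4, 2)
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()

    # Logs every parameter whose gradient contains NaN values (nothing here, since the grads are finite).
    print_nan_gradients(model)

    # Raises ValueError if any parameter is NaN or Inf; a no-op for healthy weights.
    detect_nan_parameters(model)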
@@ -29,6 +29,7 @@ from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType, par
 from pytorch_lightning.utilities.distributed import rank_zero_info
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
 from pytorch_lightning.utilities.parsing import AttributeDict
 from pytorch_lightning.utilities.warnings import WarningCache

@@ -636,7 +637,7 @@ class TrainLoop:
             # check if loss or model weights are nan
             if self.trainer.terminate_on_nan:
-                self.trainer.detect_nan_tensors(opt_closure_result.loss)
+                self._check_finite(opt_closure_result.loss)

             # track all the outputs across all steps
             batch_opt_idx = opt_idx if len(batch_outputs) > 1 else 0

@@ -678,7 +679,7 @@ class TrainLoop:
                     # check if loss or model weights are nan
                     if self.trainer.terminate_on_nan:
-                        self.trainer.detect_nan_tensors(result.loss)
+                        self._check_finite(result.loss)

                 else:
                     self.warning_cache.warn("training_step returned None if it was on purpose, ignore this warning...")

@@ -689,6 +690,12 @@ class TrainLoop:
         return result

+    def _check_finite(self, loss: torch.Tensor) -> None:
+        if not torch.isfinite(loss).all():
+            raise ValueError(f'The loss returned in `training_step` is {loss}.')
+        model = self.trainer.lightning_module
+        detect_nan_parameters(model)
+
     def backward(self, result, optimizer, opt_idx, *args, **kwargs):
         self.trainer.dev_debugger.track_event("backward_call")
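With this change, `Trainer(terminate_on_nan=True)` routes the check through the new `_check_finite` helper: a non-finite loss raises a `ValueError` that embeds the offending value, and `detect_nan_parameters` then validates the model weights. A sketch of how a user would hit this path; the toy module and the forced `inf` loss are illustrative assumptions, mirroring `test_nan_loss_detection` further below:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    import pytorch_lightning as pl


    class NaNLossModel(pl.LightningModule):
        """Toy module whose loss becomes non-finite after a few steps (illustrative only)."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def training_step(self, batch, batch_idx):
            (x,) = batch
            loss = self.layer(x).sum()
            if batch_idx >= 2:
                loss = loss + float("inf")  # force a non-finite loss
            return loss

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

        def train_dataloader(self):
            return DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=2)


    # terminate_on_nan=True makes the loop run the finite checks after each backward;
    # the run stops with a ValueError naming the offending loss or parameter.
    trainer = pl.Trainer(max_epochs=1, terminate_on_nan=True)
    trainer.fit(NaNLossModel())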
@@ -19,6 +19,8 @@ import torch
 from torch import Tensor

 from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.utilities import rank_zero_deprecation
+from pytorch_lightning.utilities.finite_checks import detect_nan_parameters, print_nan_gradients

 EPSILON = 1e-6
 EPSILON_FP16 = 1e-5

@@ -26,28 +28,33 @@ log = logging.getLogger(__name__)
 class TrainerTrainingTricksMixin(ABC):
+    """
+    TODO: Remove this class in v1.5.
+
+    Use the NaN utilities from ``pytorch_lightning.utilities.finite_checks`` instead.
+    """

     # this is just a summary on variables used in this abstract class,
     # the proper values/initialisation should be done in child class
     lightning_module: LightningModule

     def print_nan_gradients(self) -> None:
+        rank_zero_deprecation(
+            "Internal: TrainerTrainingTricksMixin.print_nan_gradients is deprecated in v1.3"
+            " and will be removed in v1.5."
+            " Use `pytorch_lightning.utilities.finite_checks.print_nan_gradients` instead."
+        )
         model = self.lightning_module
-        for param in model.parameters():
-            if (param.grad is not None) and torch.isnan(param.grad.float()).any():
-                log.info(param, param.grad)
+        print_nan_gradients(model)

     def detect_nan_tensors(self, loss: Tensor) -> None:
-        model = self.lightning_module
-
+        rank_zero_deprecation(
+            "Internal: TrainerTrainingTricksMixin.detect_nan_tensors is deprecated in v1.3"
+            " and will be removed in v1.5."
+            " Use `pytorch_lightning.utilities.finite_checks.detect_nan_parameters` instead."
+        )
         # check if loss is nan
         if not torch.isfinite(loss).all():
             raise ValueError('The loss returned in `training_step` is nan or inf.')
         # check if a network weight is nan
-        for name, param in model.named_parameters():
-            if not torch.isfinite(param).all():
-                self.print_nan_gradients()
-                raise ValueError(
-                    f'Detected nan and/or inf values in `{name}`.'
-                    ' Check your forward pass for numerically unstable operations.'
-                )
+        model = self.lightning_module
+        detect_nan_parameters(model)
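For code that still calls these methods through the `Trainer`, the deprecation messages point at drop-in replacements. A hedged migration sketch; the helper names below are hypothetical, and the `trainer` passed in is assumed to be a `pytorch_lightning.Trainer` with a LightningModule already attached:

    import torch
    from torch import Tensor

    from pytorch_lightning.utilities.finite_checks import detect_nan_parameters, print_nan_gradients


    def check_finite(trainer, loss: Tensor) -> None:
        """Rough equivalent of the deprecated ``trainer.detect_nan_tensors(loss)``."""
        # the loss check stays the same as in the old mixin
        if not torch.isfinite(loss).all():
            raise ValueError('The loss returned in `training_step` is nan or inf.')
        # the parameter sweep now lives in pytorch_lightning.utilities.finite_checks
        detect_nan_parameters(trainer.lightning_module)


    def log_nan_gradients(trainer) -> None:
        """Rough equivalent of the deprecated ``trainer.print_nan_gradients()``."""
        print_nan_gradients(trainer.lightning_module)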
@@ -0,0 +1,39 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Helper functions to detect NaN/Inf values. """
+
+import logging
+
+import torch
+import torch.nn as nn
+
+log = logging.getLogger(__name__)
+
+
+def print_nan_gradients(model: nn.Module) -> None:
+    """ Iterates over model parameters and prints out parameter + gradient information if NaN. """
+    for param in model.parameters():
+        if (param.grad is not None) and torch.isnan(param.grad.float()).any():
+            log.info(param, param.grad)
+
+
+def detect_nan_parameters(model: nn.Module) -> None:
+    """ Iterates over model parameters and prints gradients if any parameter is not finite. """
+    for name, param in model.named_parameters():
+        if not torch.isfinite(param).all():
+            print_nan_gradients(model)
+            raise ValueError(
+                f'Detected nan and/or inf values in `{name}`.'
+                ' Check your forward pass for numerically unstable operations.'
+            )
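A small, self-contained sketch of the new module's failure mode, assuming nothing beyond a plain `nn.Linear`; corrupting one weight makes `detect_nan_parameters` raise:

    import math

    import torch
    import torch.nn as nn

    from pytorch_lightning.utilities.finite_checks import detect_nan_parameters

    model = nn.Linear(3, 3)
    detect_nan_parameters(model)  # healthy weights: no exception

    # Corrupt one weight to trigger the check.
    with torch.no_grad():
        model.weight[0, 0] = math.nan

    try:
        detect_nan_parameters(model)
    except ValueError as err:
        # "Detected nan and/or inf values in `weight`. Check your forward pass ..."
        print(err)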
@@ -15,6 +15,7 @@
 from unittest import mock

 import pytest
+import torch
 from torch import optim

 from pytorch_lightning import Callback, Trainer

@@ -218,3 +219,15 @@ def test_v1_5_0_profiler_output_filename(tmpdir, cls):
     profiler = cls(output_filename=filepath)
     assert profiler.dirpath == tmpdir
     assert profiler.filename == "test"
+
+
+def test_v1_5_0_trainer_training_trick_mixin(tmpdir):
+    model = BoringModel()
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, checkpoint_callback=False, logger=False)
+    trainer.fit(model)
+    with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"):
+        trainer.print_nan_gradients()
+
+    dummy_loss = torch.tensor(1.0)
+    with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"):
+        trainer.detect_nan_tensors(dummy_loss)

@@ -799,7 +799,7 @@ def test_nan_loss_detection(tmpdir):
         terminate_on_nan=True,
     )

-    with pytest.raises(ValueError, match=r".*The loss returned in `training_step` is nan or inf.*"):
+    with pytest.raises(ValueError, match=r".*The loss returned in `training_step` is.*"):
         trainer.fit(model)
         assert trainer.global_step == model.test_step_inf_loss
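The loosened regex in this last hunk reflects the new error text from `_check_finite`, which embeds the offending loss value instead of the fixed phrase "nan or inf". A tiny illustration of the new message format (the variable names are illustrative):

    import torch

    loss = torch.tensor(float("inf"))
    assert not torch.isfinite(loss).all()  # the condition `_check_finite` tests
    # Embedding the tensor yields e.g. "The loss returned in `training_step` is tensor(inf)."
    print(f"The loss returned in `training_step` is {loss}.")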