Move NaN/Inf detection to a separate utilities file (#6834)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
ananthsub 2021-04-08 16:47:02 -07:00 committed by GitHub
parent 90e37ba458
commit 851f9e3997
6 changed files with 87 additions and 16 deletions

View File

@@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Added utils for NaN/Inf detection for gradients and parameters ([#6834](https://github.com/PyTorchLightning/pytorch-lightning/pull/6834/))
- Added more explicit exception message when trying to execute `trainer.test()` or `trainer.validate()` with `fast_dev_run=True` ([#6667](https://github.com/PyTorchLightning/pytorch-lightning/pull/6667))
@@ -113,6 +115,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Deprecated
- Deprecated `TrainerTrainingTricksMixin` in favor of a separate utilities module for NaN/Inf detection for gradients and parameters ([#6834](https://github.com/PyTorchLightning/pytorch-lightning/pull/6834/))
- `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
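(Illustrative, not part of the diff: the "utils for NaN/Inf detection" added above are two functions, one for gradients and one for parameters. A minimal sketch of calling them on a plain module; the toy `nn.Linear` is an assumption for illustration, while the function names come from the new `pytorch_lightning.utilities.finite_checks` module shown further down.)

import torch.nn as nn
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters, print_nan_gradients

model = nn.Linear(4, 2)        # any nn.Module works; no Trainer required
print_nan_gradients(model)     # logs parameter + gradient only where a gradient contains NaN
detect_nan_parameters(model)   # raises ValueError if any parameter is NaN/Inf, otherwise silent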

View File

@@ -29,6 +29,7 @@ from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType, par
from pytorch_lightning.utilities.distributed import rank_zero_info
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.warnings import WarningCache
@@ -636,7 +637,7 @@ class TrainLoop:
# check if loss or model weights are nan
if self.trainer.terminate_on_nan:
self.trainer.detect_nan_tensors(opt_closure_result.loss)
self._check_finite(opt_closure_result.loss)
# track all the outputs across all steps
batch_opt_idx = opt_idx if len(batch_outputs) > 1 else 0
@@ -678,7 +679,7 @@ class TrainLoop:
# check if loss or model weights are nan
if self.trainer.terminate_on_nan:
self.trainer.detect_nan_tensors(result.loss)
self._check_finite(result.loss)
else:
self.warning_cache.warn("training_step returned None if it was on purpose, ignore this warning...")
@@ -689,6 +690,12 @@ class TrainLoop:
return result
def _check_finite(self, loss: torch.Tensor) -> None:
if not torch.isfinite(loss).all():
raise ValueError(f'The loss returned in `training_step` is {loss}.')
model = self.trainer.lightning_module
detect_nan_parameters(model)
def backward(self, result, optimizer, opt_idx, *args, **kwargs):
self.trainer.dev_debugger.track_event("backward_call")
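(Illustrative sketch, not part of the diff: how the new `_check_finite` hook surfaces to users. With `terminate_on_nan=True`, a non-finite loss now raises through `_check_finite` instead of `trainer.detect_nan_tensors`. The toy model and dataloader below are assumptions; `Trainer`, `terminate_on_nan`, and the error text come from the diff.)

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer

class NaNLossModel(LightningModule):
    """Toy model whose training loss is forced to NaN."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum() * float("nan")  # force a non-finite loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)

trainer = Trainer(max_epochs=1, terminate_on_nan=True, logger=False, checkpoint_callback=False)
# Raises ValueError("The loss returned in `training_step` is ...") via TrainLoop._check_finite
trainer.fit(NaNLossModel())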

View File

@@ -19,6 +19,8 @@ import torch
from torch import Tensor
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters, print_nan_gradients
EPSILON = 1e-6
EPSILON_FP16 = 1e-5
@@ -26,28 +28,33 @@ log = logging.getLogger(__name__)
class TrainerTrainingTricksMixin(ABC):
"""
TODO: Remove this class in v1.5.
Use the NaN utilities from ``pytorch_lightning.utilities.finite_checks`` instead.
"""
# this is just a summary on variables used in this abstract class,
# the proper values/initialisation should be done in child class
lightning_module: LightningModule
def print_nan_gradients(self) -> None:
rank_zero_deprecation(
"Internal: TrainerTrainingTricksMixin.print_nan_gradients is deprecated in v1.3"
" and will be removed in v1.5."
" Use `pytorch_lightning.utilities.finite_checks.print_nan_gradients` instead."
)
model = self.lightning_module
for param in model.parameters():
if (param.grad is not None) and torch.isnan(param.grad.float()).any():
log.info(param, param.grad)
print_nan_gradients(model)
def detect_nan_tensors(self, loss: Tensor) -> None:
model = self.lightning_module
rank_zero_deprecation(
"Internal: TrainerTrainingTricksMixin.detect_nan_tensors is deprecated in v1.3"
" and will be removed in v1.5."
" Use `pytorch_lightning.utilities.finite_checks.detect_nan_parameters` instead."
)
# check if loss is nan
if not torch.isfinite(loss).all():
raise ValueError('The loss returned in `training_step` is nan or inf.')
# check if a network weight is nan
for name, param in model.named_parameters():
if not torch.isfinite(param).all():
self.print_nan_gradients()
raise ValueError(
f'Detected nan and/or inf values in `{name}`.'
' Check your forward pass for numerically unstable operations.'
)
model = self.lightning_module
detect_nan_parameters(model)
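(Illustrative migration sketch, not part of the diff: what the deprecation messages above ask callers to do. The `nn.Linear` and the dummy loss are stand-ins for `trainer.lightning_module` and the loss returned by `training_step`; the replacement calls are the ones named in the deprecation messages.)

import torch
import torch.nn as nn
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters, print_nan_gradients

model = nn.Linear(4, 2)    # stand-in for trainer.lightning_module
loss = torch.tensor(1.0)   # stand-in for the loss returned by training_step

# Before (deprecated, scheduled for removal in v1.5):
#   trainer.print_nan_gradients()
#   trainer.detect_nan_tensors(loss)

# After: call the utilities directly; the loss check itself now lives in the training loop.
print_nan_gradients(model)
if not torch.isfinite(loss).all():
    raise ValueError(f"The loss returned in `training_step` is {loss}.")
detect_nan_parameters(model)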

View File

@@ -0,0 +1,39 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions to detect NaN/Inf values. """
import logging
import torch
import torch.nn as nn
log = logging.getLogger(__name__)
def print_nan_gradients(model: nn.Module) -> None:
""" Iterates over model parameters and prints out parameter + gradient information if NaN. """
for param in model.parameters():
if (param.grad is not None) and torch.isnan(param.grad.float()).any():
log.info(param, param.grad)
def detect_nan_parameters(model: nn.Module) -> None:
""" Iterates over model parameters and prints gradients if any parameter is not finite. """
for name, param in model.named_parameters():
if not torch.isfinite(param).all():
print_nan_gradients(model)
raise ValueError(
f'Detected nan and/or inf values in `{name}`.'
' Check your forward pass for numerically unstable operations.'
)
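(Illustrative, not part of the diff: behaviour of the new module when a parameter becomes non-finite. The corrupted `nn.Linear` is a contrived assumption; the error text comes from `detect_nan_parameters` above.)

import torch
import torch.nn as nn
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters

model = nn.Linear(2, 2)
detect_nan_parameters(model)           # all parameters finite: returns silently

with torch.no_grad():
    model.weight[0, 0] = float("nan")  # corrupt one weight to trigger the check
try:
    detect_nan_parameters(model)
except ValueError as err:
    print(err)  # Detected nan and/or inf values in `weight`. Check your forward pass ...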

View File

@@ -15,6 +15,7 @@
from unittest import mock
import pytest
import torch
from torch import optim
from pytorch_lightning import Callback, Trainer
@@ -218,3 +219,15 @@ def test_v1_5_0_profiler_output_filename(tmpdir, cls):
profiler = cls(output_filename=filepath)
assert profiler.dirpath == tmpdir
assert profiler.filename == "test"
def test_v1_5_0_trainer_training_trick_mixin(tmpdir):
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, checkpoint_callback=False, logger=False)
trainer.fit(model)
with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"):
trainer.print_nan_gradients()
dummy_loss = torch.tensor(1.0)
with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"):
trainer.detect_nan_tensors(dummy_loss)

View File

@@ -799,7 +799,7 @@ def test_nan_loss_detection(tmpdir):
terminate_on_nan=True,
)
with pytest.raises(ValueError, match=r".*The loss returned in `training_step` is nan or inf.*"):
with pytest.raises(ValueError, match=r".*The loss returned in `training_step` is.*"):
trainer.fit(model)
assert trainer.global_step == model.test_step_inf_loss