From 851f9e3997db8d8e077254e3a996b8835d59dc3c Mon Sep 17 00:00:00 2001 From: ananthsub Date: Thu, 8 Apr 2021 16:47:02 -0700 Subject: [PATCH] Move NaN/Inf detection to a separate utilities file (#6834) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ Co-authored-by: Jirka Borovec --- CHANGELOG.md | 5 +++ pytorch_lightning/trainer/training_loop.py | 11 +++++- pytorch_lightning/trainer/training_tricks.py | 33 ++++++++++------- pytorch_lightning/utilities/finite_checks.py | 39 ++++++++++++++++++++ tests/deprecated_api/test_remove_1-5.py | 13 +++++++ tests/trainer/test_trainer.py | 2 +- 6 files changed, 87 insertions(+), 16 deletions(-) create mode 100644 pytorch_lightning/utilities/finite_checks.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 504d4ea236..069bb96cd6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added utils for NaN/Inf detection for gradients and parameters ([#6834](https://github.com/PyTorchLightning/pytorch-lightning/pull/6834/)) + - Added more explicit exception message when trying to execute `trainer.test()` or `trainer.validate()` with `fast_dev_run=True` ([#6667](https://github.com/PyTorchLightning/pytorch-lightning/pull/6667)) @@ -113,6 +115,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated +- Deprecated `TrainerTrainingTricksMixin` in favor of a separate utilities module for NaN/Inf detection for gradients and parameters ([#6834](https://github.com/PyTorchLightning/pytorch-lightning/pull/6834/)) + + - `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 71d9407062..d274973381 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -29,6 +29,7 @@ from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType, par from pytorch_lightning.utilities.distributed import rank_zero_info from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.finite_checks import detect_nan_parameters from pytorch_lightning.utilities.parsing import AttributeDict from pytorch_lightning.utilities.warnings import WarningCache @@ -636,7 +637,7 @@ class TrainLoop: # check if loss or model weights are nan if self.trainer.terminate_on_nan: - self.trainer.detect_nan_tensors(opt_closure_result.loss) + self._check_finite(opt_closure_result.loss) # track all the outputs across all steps batch_opt_idx = opt_idx if len(batch_outputs) > 1 else 0 @@ -678,7 +679,7 @@ class TrainLoop: # check if loss or model weights are nan if self.trainer.terminate_on_nan: - self.trainer.detect_nan_tensors(result.loss) + self._check_finite(result.loss) else: self.warning_cache.warn("training_step returned None if it was on purpose, ignore this warning...") @@ -689,6 +690,12 @@ class TrainLoop: return result + def _check_finite(self, loss: torch.Tensor) -> None: + if not torch.isfinite(loss).all(): + raise ValueError(f'The loss returned in `training_step` is {loss}.') + model = self.trainer.lightning_module + detect_nan_parameters(model) + def backward(self, result, optimizer, opt_idx, *args, **kwargs): self.trainer.dev_debugger.track_event("backward_call") diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py index 2795dd4f0a..42ab03de8d 100644 --- a/pytorch_lightning/trainer/training_tricks.py +++ b/pytorch_lightning/trainer/training_tricks.py @@ -19,6 +19,8 @@ import torch from torch import Tensor from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.utilities import rank_zero_deprecation +from pytorch_lightning.utilities.finite_checks import detect_nan_parameters, print_nan_gradients EPSILON = 1e-6 EPSILON_FP16 = 1e-5 @@ -26,28 +28,33 @@ log = logging.getLogger(__name__) class TrainerTrainingTricksMixin(ABC): + """ + TODO: Remove this class in v1.5. + + Use the NaN utilities from ``pytorch_lightning.utilities.finite_checks`` instead. + """ # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class lightning_module: LightningModule def print_nan_gradients(self) -> None: + rank_zero_deprecation( + "Internal: TrainerTrainingTricksMixin.print_nan_gradients is deprecated in v1.3" + " and will be removed in v1.5." + " Use `pytorch_lightning.utilities.finite_checks.print_nan_gradients` instead." + ) model = self.lightning_module - for param in model.parameters(): - if (param.grad is not None) and torch.isnan(param.grad.float()).any(): - log.info(param, param.grad) + print_nan_gradients(model) def detect_nan_tensors(self, loss: Tensor) -> None: - model = self.lightning_module - + rank_zero_deprecation( + "Internal: TrainerTrainingTricksMixin.detect_nan_tensors is deprecated in v1.3" + " and will be removed in v1.5." + " Use `pytorch_lightning.utilities.finite_checks.detect_nan_parameters` instead." + ) # check if loss is nan if not torch.isfinite(loss).all(): raise ValueError('The loss returned in `training_step` is nan or inf.') - # check if a network weight is nan - for name, param in model.named_parameters(): - if not torch.isfinite(param).all(): - self.print_nan_gradients() - raise ValueError( - f'Detected nan and/or inf values in `{name}`.' - ' Check your forward pass for numerically unstable operations.' - ) + model = self.lightning_module + detect_nan_parameters(model) diff --git a/pytorch_lightning/utilities/finite_checks.py b/pytorch_lightning/utilities/finite_checks.py new file mode 100644 index 0000000000..770ea7a227 --- /dev/null +++ b/pytorch_lightning/utilities/finite_checks.py @@ -0,0 +1,39 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Helper functions to detect NaN/Inf values. """ + +import logging + +import torch +import torch.nn as nn + +log = logging.getLogger(__name__) + + +def print_nan_gradients(model: nn.Module) -> None: + """ Iterates over model parameters and prints out parameter + gradient information if NaN. """ + for param in model.parameters(): + if (param.grad is not None) and torch.isnan(param.grad.float()).any(): + log.info(param, param.grad) + + +def detect_nan_parameters(model: nn.Module) -> None: + """ Iterates over model parameters and prints gradients if any parameter is not finite. """ + for name, param in model.named_parameters(): + if not torch.isfinite(param).all(): + print_nan_gradients(model) + raise ValueError( + f'Detected nan and/or inf values in `{name}`.' + ' Check your forward pass for numerically unstable operations.' + ) diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index fc3fe3112e..8757fb625d 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -15,6 +15,7 @@ from unittest import mock import pytest +import torch from torch import optim from pytorch_lightning import Callback, Trainer @@ -218,3 +219,15 @@ def test_v1_5_0_profiler_output_filename(tmpdir, cls): profiler = cls(output_filename=filepath) assert profiler.dirpath == tmpdir assert profiler.filename == "test" + + +def test_v1_5_0_trainer_training_trick_mixin(tmpdir): + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, checkpoint_callback=False, logger=False) + trainer.fit(model) + with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"): + trainer.print_nan_gradients() + + dummy_loss = torch.tensor(1.0) + with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"): + trainer.detect_nan_tensors(dummy_loss) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index cbba0b7a45..447ed2b41b 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -799,7 +799,7 @@ def test_nan_loss_detection(tmpdir): terminate_on_nan=True, ) - with pytest.raises(ValueError, match=r".*The loss returned in `training_step` is nan or inf.*"): + with pytest.raises(ValueError, match=r".*The loss returned in `training_step` is.*"): trainer.fit(model) assert trainer.global_step == model.test_step_inf_loss