From 8544b334e4af9caa060a280146e7d3bb10648332 Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Mon, 13 Apr 2020 19:06:25 +0100
Subject: [PATCH] Replace automatic nan check with optional flag (#1475)

* Replace automatic nan check with optional flag

* Update CHANGELOG.md
---
 CHANGELOG.md                               |  6 +++---
 pytorch_lightning/trainer/trainer.py       |  5 +++++
 pytorch_lightning/trainer/training_loop.py | 15 +++++++++++++--
 tests/trainer/test_trainer.py              |  2 ++
 4 files changed, 23 insertions(+), 5 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7430c95d38..b5215646ab 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,13 +9,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
 - Added `auto_select_gpus` flag to trainer that enables automatic selection of available GPUs on exclusive mode systems.
-- Added learining rate finder ([#1347](https://github.com/PyTorchLightning/pytorch-lightning/pull/1347))
+- Added learning rate finder ([#1347](https://github.com/PyTorchLightning/pytorch-lightning/pull/1347))
 
--
+- Added `terminate_on_nan` flag to trainer that performs a NaN check with each training iteration when set to `True`. ([#1475](https://github.com/PyTorchLightning/pytorch-lightning/pull/1475))
 
 ### Changed
 
--
+- Changed the default behaviour to no longer include a NaN check with each training iteration. ([#1475](https://github.com/PyTorchLightning/pytorch-lightning/pull/1475))
 
 -
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 0e2e4c15cf..46b4fc48e6 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -134,6 +134,7 @@ class Trainer(
             use_amp=None,  # backward compatible, todo: remove in v0.9.0
             show_progress_bar=None,  # backward compatible, todo: remove in v0.9.0
             nb_sanity_val_steps=None,  # backward compatible, todo: remove in v0.8.0
+            terminate_on_nan: bool = False,
             **kwargs
     ):
         r"""
@@ -281,6 +282,9 @@ class Trainer(
                 To use a different key, set a string instead of True with the key name.
 
             benchmark: If true enables cudnn.benchmark.
+
+            terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
+                end of each training batch, if any of the parameters or the loss are NaN or +/-inf.
         """
 
         # Init callbacks
@@ -357,6 +361,7 @@ class Trainer(
 
         self.truncated_bptt_steps = truncated_bptt_steps
         self.resume_from_checkpoint = resume_from_checkpoint
+        self.terminate_on_nan = terminate_on_nan
         self.shown_warnings = set()
 
         self.fast_dev_run = fast_dev_run
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 06ac7849d7..5a84422049 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -121,7 +121,8 @@ When this flag is enabled each batch is split into sequences of size truncated_b
 
 NaN detection and intervention
 ------------------------------
-In every forward pass in training, Lightning will check that
+When the `terminate_on_nan` flag is enabled, after every forward pass during training, Lightning will
+check that
 
 1. the loss you return in `training_step` is finite (not NaN and not +/-inf)
 2. the model parameters have finite values.
@@ -130,6 +131,14 @@ Lightning will terminate the training loop with an error message if NaN or infin
 values are detected. If this happens, you should investigate numerically unstable operations
 in your model.
 
+.. code-block:: python
+
+    # DEFAULT (won't perform the NaN check)
+    trainer = Trainer(terminate_on_nan=False)
+
+    # (NaN check each batch and terminate on NaN or infinite values)
+    trainer = Trainer(terminate_on_nan=True)
+
 """
 
 import copy
@@ -216,6 +225,7 @@ class TrainerTrainLoopMixin(ABC):
     min_steps: int
     total_batch_idx: int
     checkpoint_callback: ...
+    terminate_on_nan: bool
 
     # Callback system
     callbacks: List[Callback]
@@ -604,7 +614,8 @@ class TrainerTrainLoopMixin(ABC):
                 loss, batch_output = optimizer_closure()
 
                 # check if loss or model weights are nan
-                self.detect_nan_tensors(loss)
+                if self.terminate_on_nan:
+                    self.detect_nan_tensors(loss)
 
                 # track total loss for logging (avoid mem leaks)
                 self.batch_loss_value.append(loss)
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 02d61c9cc7..2a0f9180a9 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -593,6 +593,7 @@ def test_nan_loss_detection(tmpdir):
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_steps=(test_step + 1),
+        terminate_on_nan=True
     )
 
     with pytest.raises(ValueError, match=r'.*The loss returned in `training_step` is nan or inf.*'):
@@ -619,6 +620,7 @@ def test_nan_params_detection(tmpdir):
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_steps=(test_step + 1),
+        terminate_on_nan=True
    )
 
     with pytest.raises(ValueError, match=r'.*Detected nan and/or inf values in `c_d1.bias`.*'):
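
For context, the per-batch check that the new `terminate_on_nan` flag gates is conceptually along these lines (a minimal sketch only; the standalone helper `check_finite` is assumed for illustration and is not Lightning's actual `detect_nan_tensors` implementation):

.. code-block:: python

    import torch

    def check_finite(loss, model):
        # sketch only; not Lightning's actual detect_nan_tensors code
        # fail if the loss returned from `training_step` is NaN or +/-inf
        if not torch.isfinite(loss).all():
            raise ValueError('The loss returned in `training_step` is nan or inf.')
        # fail if any model parameter has become NaN or +/-inf
        for name, param in model.named_parameters():
            if not torch.isfinite(param).all():
                raise ValueError(f'Detected nan and/or inf values in `{name}`.')

With the patch applied, this behaviour is opt-in: `Trainer(terminate_on_nan=True)` runs the check after every training batch, while the default `Trainer()` skips it.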