Replace automatic nan check with optional flag (#1475)

* Replace automatic nan check with optional flag

* Update CHANGELOG.md
Ethan Harris 2020-04-13 19:06:25 +01:00 committed by GitHub
parent 3f1e4b953f
commit 8544b334e4
4 changed files with 23 additions and 5 deletions


@@ -9,13 +9,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Added `auto_select_gpus` flag to trainer that enables automatic selection of available GPUs on exclusive mode systems.
- Added learining rate finder ([#1347](https://github.com/PyTorchLightning/pytorch-lightning/pull/1347))
- Added learning rate finder ([#1347](https://github.com/PyTorchLightning/pytorch-lightning/pull/1347))
-
- Added `terminate_on_nan` flag to trainer that performs a NaN check with each training iteration when set to `True`. ([#1475](https://github.com/PyTorchLightning/pytorch-lightning/pull/1475))
### Changed
-
- Changed the default behaviour to no longer include a NaN check with each training iteration. ([#1475](https://github.com/PyTorchLightning/pytorch-lightning/pull/1475))
-
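Taken together, the two entries above make the per-batch NaN check opt-in rather than automatic. A minimal usage sketch, using nothing beyond the public `Trainer` constructor as changed by this commit:

from pytorch_lightning import Trainer

# New default after this commit: no NaN/inf check runs on each training batch.
trainer = Trainer()

# Opt in: check the loss and model parameters after every training batch and
# raise a ValueError if NaN or +/-inf values are found.
trainer = Trainer(terminate_on_nan=True)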


@@ -134,6 +134,7 @@ class Trainer(
use_amp=None, # backward compatible, todo: remove in v0.9.0
show_progress_bar=None, # backward compatible, todo: remove in v0.9.0
nb_sanity_val_steps=None, # backward compatible, todo: remove in v0.8.0
terminate_on_nan: bool = False,
**kwargs
):
r"""
@@ -281,6 +282,9 @@ class Trainer(
To use a different key, set a string instead of True with the key name.
benchmark: If true enables cudnn.benchmark.
terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
end of each training batch, if any of the parameters or the loss are NaN or +/-inf.
"""
# Init callbacks
@@ -357,6 +361,7 @@ class Trainer(
self.truncated_bptt_steps = truncated_bptt_steps
self.resume_from_checkpoint = resume_from_checkpoint
self.terminate_on_nan = terminate_on_nan
self.shown_warnings = set()
self.fast_dev_run = fast_dev_run
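For context on what the new attribute gates: the first of the two checks is that the loss returned from `training_step` is finite. A simplified sketch of that loss check, assuming only `torch.isfinite`; this is illustrative, not the library's actual `detect_nan_tensors` implementation:

import torch

def check_loss_is_finite(loss: torch.Tensor) -> None:
    # Illustrative only: raises with a message similar to the one the tests
    # below match, when the training loss contains NaN or +/-inf values.
    if not torch.isfinite(loss).all():
        raise ValueError('The loss returned in `training_step` is nan or inf.')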


@@ -121,7 +121,8 @@ When this flag is enabled each batch is split into sequences of size truncated_b
NaN detection and intervention
------------------------------
In every forward pass in training, Lightning will check that
When the `terminate_on_nan` flag is enabled, after every forward pass during training, Lightning will
check that
1. the loss you return in `training_step` is finite (not NaN and not +/-inf)
2. the model parameters have finite values.
@@ -130,6 +131,14 @@ Lightning will terminate the training loop with an error message if NaN or infin
values are detected. If this happens, you should investigate numerically unstable operations
in your model.
.. code-block:: python
# DEFAULT (won't perform the NaN check)
trainer = Trainer(terminate_on_nan=False)
# (NaN check each batch and terminate on NaN or infinite values)
trainer = Trainer(terminate_on_nan=True)
"""
import copy
@@ -216,6 +225,7 @@ class TrainerTrainLoopMixin(ABC):
min_steps: int
total_batch_idx: int
checkpoint_callback: ...
terminate_on_nan: bool
# Callback system
callbacks: List[Callback]
@@ -604,7 +614,8 @@ class TrainerTrainLoopMixin(ABC):
loss, batch_output = optimizer_closure()
# check if loss or model weights are nan
self.detect_nan_tensors(loss)
if self.terminate_on_nan:
self.detect_nan_tensors(loss)
# track total loss for logging (avoid mem leaks)
self.batch_loss_value.append(loss)
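The same gated call also covers the second documented check: that every model parameter stays finite. A hedged sketch of that parameter sweep, with the error message modelled on the one the parameter test below expects; the function name and structure are illustrative, not the library's exact code:

import torch

def check_params_are_finite(model: torch.nn.Module) -> None:
    # Illustrative only: walk the named parameters and report the first
    # offending tensor by name, e.g. `c_d1.bias` in the test below.
    for name, param in model.named_parameters():
        if not torch.isfinite(param).all():
            raise ValueError(f'Detected nan and/or inf values in `{name}`.')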


@@ -593,6 +593,7 @@ def test_nan_loss_detection(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=(test_step + 1),
terminate_on_nan=True
)
with pytest.raises(ValueError, match=r'.*The loss returned in `training_step` is nan or inf.*'):
@@ -619,6 +620,7 @@ def test_nan_params_detection(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=(test_step + 1),
terminate_on_nan=True
)
with pytest.raises(ValueError, match=r'.*Detected nan and/or inf values in `c_d1.bias`.*'):
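
For completeness, a model that exercises these tests only needs to produce a NaN loss at a chosen step. A hypothetical fixture sketch; `NanLossModel`, its layer sizes, and the data loader are all illustrative assumptions, not the fixture used by the repository's test suite:

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule

class NanLossModel(LightningModule):
    # Hypothetical fixture: returns a NaN loss on a fixed step so that a run
    # with Trainer(terminate_on_nan=True) raises the expected ValueError.
    def __init__(self, nan_step: int = 8):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)
        self.nan_step = nan_step

    def training_step(self, batch, batch_idx):
        (x,) = batch
        loss = self.layer(x).mean()
        if batch_idx == self.nan_step:
            loss = loss * float('nan')
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=4)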