Replace automatic nan check with optional flag (#1475)
* Replace automatic nan check with optional flag * Update CHANGELOG.md
This commit is contained in:
parent
3f1e4b953f
commit
8544b334e4
|
@ -9,13 +9,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
### Added
|
||||
|
||||
- Added `auto_select_gpus` flag to trainer that enables automatic selection of available GPUs on exclusive mode systems.
|
||||
- Added learining rate finder ([#1347](https://github.com/PyTorchLightning/pytorch-lightning/pull/1347))
|
||||
- Added learning rate finder ([#1347](https://github.com/PyTorchLightning/pytorch-lightning/pull/1347))
|
||||
|
||||
-
|
||||
- Added `terminate_on_nan` flag to trainer that performs a NaN check with each training iteration when set to `True`. ([#1475](https://github.com/PyTorchLightning/pytorch-lightning/pull/1475))
|
||||
|
||||
### Changed
|
||||
|
||||
-
|
||||
- Changed the default behaviour to no longer include a NaN check with each training iteration. ([#1475](https://github.com/PyTorchLightning/pytorch-lightning/pull/1475))
|
||||
|
||||
-
|
||||
|
||||
|
|
|
@ -134,6 +134,7 @@ class Trainer(
|
|||
use_amp=None, # backward compatible, todo: remove in v0.9.0
|
||||
show_progress_bar=None, # backward compatible, todo: remove in v0.9.0
|
||||
nb_sanity_val_steps=None, # backward compatible, todo: remove in v0.8.0
|
||||
terminate_on_nan: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
r"""
|
||||
|
@ -281,6 +282,9 @@ class Trainer(
|
|||
To use a different key, set a string instead of True with the key name.
|
||||
|
||||
benchmark: If true enables cudnn.benchmark.
|
||||
|
||||
terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
|
||||
end of each training batch, if any of the parameters or the loss is NaN or +/-inf.
|
||||
"""
|
||||
|
||||
# Init callbacks
|
||||
|
@ -357,6 +361,7 @@ class Trainer(
|
|||
|
||||
self.truncated_bptt_steps = truncated_bptt_steps
|
||||
self.resume_from_checkpoint = resume_from_checkpoint
|
||||
self.terminate_on_nan = terminate_on_nan
|
||||
self.shown_warnings = set()
|
||||
|
||||
self.fast_dev_run = fast_dev_run
|
||||
|
|
|
@ -121,7 +121,8 @@ When this flag is enabled each batch is split into sequences of size truncated_b
|
|||
|
||||
NaN detection and intervention
|
||||
------------------------------
|
||||
In every forward pass in training, Lightning will check that
|
||||
When the `terminate_on_nan` flag is enabled, after every forward pass during training, Lightning will
|
||||
check that
|
||||
|
||||
1. the loss you return in `training_step` is finite (not NaN and not +/-inf)
|
||||
2. the model parameters have finite values.
|
||||
|
@ -130,6 +131,14 @@ Lightning will terminate the training loop with an error message if NaN or infin
|
|||
values are detected. If this happens, you should investigate numerically unstable operations
|
||||
in your model.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# DEFAULT (won't perform the NaN check)
|
||||
trainer = Trainer(terminate_on_nan=False)
|
||||
|
||||
# (check for NaN each batch and terminate on NaN or infinite values)
|
||||
trainer = Trainer(terminate_on_nan=True)
|
||||
|
||||
"""
|
||||
|
||||
import copy
|
||||
|
@ -216,6 +225,7 @@ class TrainerTrainLoopMixin(ABC):
|
|||
min_steps: int
|
||||
total_batch_idx: int
|
||||
checkpoint_callback: ...
|
||||
terminate_on_nan: bool
|
||||
|
||||
# Callback system
|
||||
callbacks: List[Callback]
|
||||
|
@ -604,7 +614,8 @@ class TrainerTrainLoopMixin(ABC):
|
|||
loss, batch_output = optimizer_closure()
|
||||
|
||||
# check if loss or model weights are nan or inf
|
||||
self.detect_nan_tensors(loss)
|
||||
if self.terminate_on_nan:
|
||||
self.detect_nan_tensors(loss)
|
||||
|
||||
# track total loss for logging (avoid mem leaks)
|
||||
self.batch_loss_value.append(loss)
|
||||
|
|
|
@ -593,6 +593,7 @@ def test_nan_loss_detection(tmpdir):
|
|||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_steps=(test_step + 1),
|
||||
terminate_on_nan=True
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=r'.*The loss returned in `training_step` is nan or inf.*'):
|
||||
|
@ -619,6 +620,7 @@ def test_nan_params_detection(tmpdir):
|
|||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_steps=(test_step + 1),
|
||||
terminate_on_nan=True
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=r'.*Detected nan and/or inf values in `c_d1.bias`.*'):
|
||||
|
|
Loading…
Reference in New Issue