Deprecate `terminate_on_nan` Trainer argument in favor of `detect_anomaly` (#9175)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
yopknopixx authored 2021-10-11 22:47:43 +05:30, committed by GitHub
parent 6a0c47a014
commit 173f4c8466
4 changed files with 31 additions and 7 deletions

CHANGELOG.md

@@ -271,6 +271,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Deprecated

+- Deprecated trainer argument `terminate_on_nan` in favour of `detect_anomaly` ([#9175](https://github.com/PyTorchLightning/pytorch-lightning/pull/9175))
+
 - Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
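For user code the migration is a one-line swap. A minimal sketch of the change from the old flag to the new one:

```python
from pytorch_lightning import Trainer

# Before: stop training when the loss or any parameter becomes NaN/inf
# (deprecated in v1.5, scheduled for removal in 1.7)
trainer = Trainer(terminate_on_nan=True)

# After: enable torch.autograd's anomaly detection instead
trainer = Trainer(detect_anomaly=True)
```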

pytorch_lightning/trainer/connectors/training_tricks_connector.py

@@ -11,9 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Union
+from typing import Optional, Union

-from pytorch_lightning.utilities import GradClipAlgorithmType
+from pytorch_lightning.utilities import GradClipAlgorithmType, rank_zero_deprecation
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -26,10 +26,15 @@ class TrainingTricksConnector:
         gradient_clip_val: Union[int, float],
         gradient_clip_algorithm: str,
         track_grad_norm: Union[int, float, str],
-        terminate_on_nan: bool,
+        terminate_on_nan: Optional[bool],
     ):
-        if not isinstance(terminate_on_nan, bool):
-            raise TypeError(f"`terminate_on_nan` should be a bool, got {terminate_on_nan}.")
+        if terminate_on_nan is not None:
+            rank_zero_deprecation(
+                "Trainer argument `terminate_on_nan` was deprecated in v1.5 and will be removed in 1.7."
+                " Please use `Trainer(detect_anomaly=True)` instead."
+            )
+            if not isinstance(terminate_on_nan, bool):
+                raise TypeError(f"`terminate_on_nan` should be a bool, got {terminate_on_nan}.")
         # gradient clipping
         if not isinstance(gradient_clip_val, (int, float)):
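The connector routes the warning through `rank_zero_deprecation` so multi-process runs (e.g. DDP) emit it once rather than once per rank. A rough sketch of that pattern, assuming a hypothetical `get_rank()` helper; this is not Lightning's actual implementation:

```python
import warnings


def get_rank() -> int:
    # Hypothetical helper: in Lightning the rank comes from the
    # distributed environment (torch.distributed or env vars).
    return 0


def rank_zero_deprecation(message: str) -> None:
    # Emit the deprecation warning only on the process with global rank 0,
    # so a multi-GPU run does not print it once per process.
    if get_rank() == 0:
        warnings.warn(message, DeprecationWarning)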

pytorch_lightning/trainer/trainer.py

@@ -167,7 +167,7 @@ class Trainer(
         reload_dataloaders_every_epoch: bool = False,
         auto_lr_find: Union[bool, str] = False,
         replace_sampler_ddp: bool = True,
-        terminate_on_nan: bool = False,
+        detect_anomaly: bool = False,
         auto_scale_batch_size: Union[str, bool] = False,
         prepare_data_per_node: Optional[bool] = None,
         plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
@@ -177,7 +177,7 @@
         move_metrics_to_cpu: bool = False,
         multiple_trainloader_mode: str = "max_size_cycle",
         stochastic_weight_avg: bool = False,
-        detect_anomaly: bool = False,
+        terminate_on_nan: Optional[bool] = None,
     ):
         r"""
         Customize every aspect of training via flags.
@@ -351,6 +351,12 @@
             terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
                 end of each training batch, if any of the parameters or the loss are NaN or +/-inf.

+                .. deprecated:: v1.5
+                    Trainer argument ``terminate_on_nan`` was deprecated in v1.5 and will be removed in 1.7.
+                    Please use ``detect_anomaly`` instead.
+
+            detect_anomaly: Enable anomaly detection for the autograd engine.
+
             tpu_cores: How many TPU cores to train on (1 or 8) / Single TPU to train on [1]

             ipus: How many IPUs to train on.
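The two flags catch NaNs differently: `terminate_on_nan` inspected the loss and parameters after each training batch, while `detect_anomaly` switches on `torch.autograd`'s anomaly mode, which raises during `backward()` and names the forward op that produced the bad value. A self-contained sketch of the difference in plain PyTorch, outside of Lightning:

```python
import torch

x = torch.tensor([-1.0], requires_grad=True)

# Roughly what `terminate_on_nan` did: check for non-finite values after the fact.
loss = torch.sqrt(x)  # sqrt of a negative number -> NaN
if not torch.isfinite(loss).all():
    print("non-finite loss:", loss.item())

# What `detect_anomaly` turns on: autograd's anomaly mode raises in backward()
# and reports which forward function produced the NaN.
with torch.autograd.detect_anomaly():
    loss = torch.sqrt(x)
    try:
        loss.backward()
    except RuntimeError as err:
        print("anomaly detected:", err)
```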

tests/deprecated_api/test_remove_1-7.py

@@ -122,6 +122,16 @@ def test_v1_7_0_stochastic_weight_avg_trainer_constructor(tmpdir):
     _ = Trainer(stochastic_weight_avg=True)


+@pytest.mark.parametrize("terminate_on_nan", [True, False])
+def test_v1_7_0_trainer_terminate_on_nan(tmpdir, terminate_on_nan):
+    with pytest.deprecated_call(
+        match="Trainer argument `terminate_on_nan` was deprecated in v1.5 and will be removed in 1.7"
+    ):
+        trainer = Trainer(terminate_on_nan=terminate_on_nan)
+        assert trainer.terminate_on_nan is terminate_on_nan
+        assert trainer._detect_anomaly is False
+
+
 def test_v1_7_0_deprecated_on_task_dataloader(tmpdir):
     class CustomBoringModel(BoringModel):
         def on_train_dataloader(self):
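The new test relies on `pytest.deprecated_call`, which fails unless the code under the `with` block emits a deprecation warning whose message matches the given regex. A minimal, self-contained illustration of the same idiom (the names here are invented for the example):

```python
import warnings

import pytest


def legacy_api(value):
    # Toy stand-in for a deprecated code path.
    warnings.warn("`legacy_api` is deprecated, use `new_api` instead.", DeprecationWarning)
    return value


def test_legacy_api_warns():
    # Fails if no DeprecationWarning matching the pattern is raised inside the block.
    with pytest.deprecated_call(match="use `new_api` instead"):
        assert legacy_api(42) == 42
```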