Deprecate `terminate_on_nan` Trainer argument in favor of `detect_anomaly` (#9175)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
6a0c47a014
commit
173f4c8466
|
@ -271,6 +271,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Deprecated
|
||||
|
||||
- Deprecated trainer argument `terminate_on_nan` in favour of `detect_anomaly` ([#9175](https://github.com/PyTorchLightning/pytorch-lightning/pull/9175))
|
||||
|
||||
|
||||
- Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
|
||||
|
||||
|
||||
|
|
|
@ -11,9 +11,9 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from pytorch_lightning.utilities import GradClipAlgorithmType
|
||||
from pytorch_lightning.utilities import GradClipAlgorithmType, rank_zero_deprecation
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
|
||||
|
@ -26,10 +26,15 @@ class TrainingTricksConnector:
|
|||
gradient_clip_val: Union[int, float],
|
||||
gradient_clip_algorithm: str,
|
||||
track_grad_norm: Union[int, float, str],
|
||||
terminate_on_nan: bool,
|
||||
terminate_on_nan: Optional[bool],
|
||||
):
|
||||
if not isinstance(terminate_on_nan, bool):
|
||||
raise TypeError(f"`terminate_on_nan` should be a bool, got {terminate_on_nan}.")
|
||||
if terminate_on_nan is not None:
|
||||
rank_zero_deprecation(
|
||||
"Trainer argument `terminate_on_nan` was deprecated in v1.5 and will be removed in 1.7."
|
||||
" Please use `Trainer(detect_anomaly=True)` instead."
|
||||
)
|
||||
if not isinstance(terminate_on_nan, bool):
|
||||
raise TypeError(f"`terminate_on_nan` should be a bool, got {terminate_on_nan}.")
|
||||
|
||||
# gradient clipping
|
||||
if not isinstance(gradient_clip_val, (int, float)):
|
||||
|
|
|
@ -167,7 +167,7 @@ class Trainer(
|
|||
reload_dataloaders_every_epoch: bool = False,
|
||||
auto_lr_find: Union[bool, str] = False,
|
||||
replace_sampler_ddp: bool = True,
|
||||
terminate_on_nan: bool = False,
|
||||
detect_anomaly: bool = False,
|
||||
auto_scale_batch_size: Union[str, bool] = False,
|
||||
prepare_data_per_node: Optional[bool] = None,
|
||||
plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
|
||||
|
@ -177,7 +177,7 @@ class Trainer(
|
|||
move_metrics_to_cpu: bool = False,
|
||||
multiple_trainloader_mode: str = "max_size_cycle",
|
||||
stochastic_weight_avg: bool = False,
|
||||
detect_anomaly: bool = False,
|
||||
terminate_on_nan: Optional[bool] = None,
|
||||
):
|
||||
r"""
|
||||
Customize every aspect of training via flags.
|
||||
|
@ -351,6 +351,12 @@ class Trainer(
|
|||
terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
|
||||
end of each training batch, if any of the parameters or the loss are NaN or +/-inf.
|
||||
|
||||
.. deprecated:: v1.5
|
||||
Trainer argument ``terminate_on_nan`` was deprecated in v1.5 and will be removed in 1.7.
|
||||
Please use ``detect_anomaly`` instead.
|
||||
|
||||
detect_anomaly: Enable anomaly detection for the autograd engine.
|
||||
|
||||
tpu_cores: How many TPU cores to train on (1 or 8) / Single TPU to train on [1]
|
||||
|
||||
ipus: How many IPUs to train on.
|
||||
|
|
|
@ -122,6 +122,16 @@ def test_v1_7_0_stochastic_weight_avg_trainer_constructor(tmpdir):
|
|||
_ = Trainer(stochastic_weight_avg=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("terminate_on_nan", [True, False])
def test_v1_7_0_trainer_terminate_on_nan(tmpdir, terminate_on_nan):
    """Passing ``terminate_on_nan`` emits a deprecation warning but the value is still honored."""
    expected_match = "Trainer argument `terminate_on_nan` was deprecated in v1.5 and will be removed in 1.7"
    with pytest.deprecated_call(match=expected_match):
        trainer = Trainer(terminate_on_nan=terminate_on_nan)
    # The deprecated flag must still be applied, and must not implicitly enable anomaly detection.
    assert trainer.terminate_on_nan is terminate_on_nan
    assert trainer._detect_anomaly is False
|
||||
|
||||
|
||||
def test_v1_7_0_deprecated_on_task_dataloader(tmpdir):
|
||||
class CustomBoringModel(BoringModel):
|
||||
def on_train_dataloader(self):
|
||||
|
|
Loading…
Reference in New Issue