From 9a4b51d17f4f78b7179cc5ca8a5070ef7aac5ed8 Mon Sep 17 00:00:00 2001
From: Adrian Wälchli
Date: Mon, 6 Dec 2021 11:19:36 +0100
Subject: [PATCH] Fix typing in `pl.trainer.config_validator` (#10803)

---
 pyproject.toml                                       | 1 -
 pytorch_lightning/trainer/configuration_validator.py | 4 +++-
 pytorch_lightning/trainer/trainer.py                 | 2 +-
 3 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index b44152ac9b..14d4cf9370 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -94,7 +94,6 @@ module = [
     "pytorch_lightning.profiler.pytorch",
     "pytorch_lightning.profiler.simple",
     "pytorch_lightning.trainer.callback_hook",
-    "pytorch_lightning.trainer.configuration_validator",
     "pytorch_lightning.trainer.connectors.accelerator_connector",
     "pytorch_lightning.trainer.connectors.callback_connector",
     "pytorch_lightning.trainer.connectors.checkpoint_connector",
diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py
index c44529c539..8f61a050cf 100644
--- a/pytorch_lightning/trainer/configuration_validator.py
+++ b/pytorch_lightning/trainer/configuration_validator.py
@@ -28,6 +28,8 @@ def verify_loop_configurations(trainer: "pl.Trainer", model: "pl.LightningModule
         model: The model to check the configuration.

     """
+    if trainer.state.fn is None:
+        raise ValueError("Unexpected: Trainer state fn must be set before validating loop configuration.")
     if trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
         __verify_train_val_loop_configuration(trainer, model)
         __verify_manual_optimization_support(trainer, model)
@@ -221,7 +223,7 @@ def __verify_manual_optimization_support(trainer: "pl.Trainer", model: "pl.Light
         )


-def __check_training_step_requires_dataloader_iter(model: "pl.LightningModule"):
+def __check_training_step_requires_dataloader_iter(model: "pl.LightningModule") -> None:
     """Check if the current `training_step` is requesting `dataloader_iter`."""
     training_step_fx = model.training_step
     if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index f518262445..274a19333f 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -551,7 +551,7 @@ class Trainer(
         )

         self._terminate_on_nan = terminate_on_nan
-        self.gradient_clip_val = gradient_clip_val
+        self.gradient_clip_val: Union[int, float] = gradient_clip_val
         self.gradient_clip_algorithm = (
             GradClipAlgorithmType(gradient_clip_algorithm.lower())
             if gradient_clip_algorithm is not None
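
The `None` guard added to `verify_loop_configurations` is what lets this module pass mypy once its pyproject.toml exclusion is removed: `trainer.state.fn` is an Optional attribute, and raising before the membership test narrows its type for the rest of the function. Below is a minimal standalone sketch of that narrowing pattern; the classes and fields are simplified stand-ins for illustration, not Lightning's actual definitions.

from enum import Enum
from typing import Optional


class TrainerFn(Enum):
    # Simplified stand-in for pytorch_lightning.trainer.states.TrainerFn
    FITTING = "fit"
    TUNING = "tune"


class TrainerState:
    # `fn` stays None until the Trainer enters an entry point such as fit()
    fn: Optional[TrainerFn] = None


def verify_loop_configurations(state: TrainerState) -> None:
    if state.fn is None:
        raise ValueError("Trainer state fn must be set before validating.")
    # After the guard, mypy narrows `state.fn` from Optional[TrainerFn] to
    # TrainerFn, so the membership test type-checks without an ignore.
    if state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
        print(f"validating loops for {state.fn.value}")


state = TrainerState()
state.fn = TrainerFn.FITTING
verify_loop_configurations(state)  # prints: validating loops for fit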