Fix typing in `pl.trainer.config_validator` (#10803)

This commit is contained in:
Adrian Wälchli 2021-12-06 11:19:36 +01:00 committed by GitHub
parent 2fc64e9656
commit 9a4b51d17f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 4 additions and 3 deletions

View File

@ -94,7 +94,6 @@ module = [
"pytorch_lightning.profiler.pytorch", "pytorch_lightning.profiler.pytorch",
"pytorch_lightning.profiler.simple", "pytorch_lightning.profiler.simple",
"pytorch_lightning.trainer.callback_hook", "pytorch_lightning.trainer.callback_hook",
"pytorch_lightning.trainer.configuration_validator",
"pytorch_lightning.trainer.connectors.accelerator_connector", "pytorch_lightning.trainer.connectors.accelerator_connector",
"pytorch_lightning.trainer.connectors.callback_connector", "pytorch_lightning.trainer.connectors.callback_connector",
"pytorch_lightning.trainer.connectors.checkpoint_connector", "pytorch_lightning.trainer.connectors.checkpoint_connector",

View File

@ -28,6 +28,8 @@ def verify_loop_configurations(trainer: "pl.Trainer", model: "pl.LightningModule
model: The model to check the configuration. model: The model to check the configuration.
""" """
if trainer.state.fn is None:
raise ValueError("Unexpected: Trainer state fn must be set before validating loop configuration.")
if trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING): if trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
__verify_train_val_loop_configuration(trainer, model) __verify_train_val_loop_configuration(trainer, model)
__verify_manual_optimization_support(trainer, model) __verify_manual_optimization_support(trainer, model)
@ -221,7 +223,7 @@ def __verify_manual_optimization_support(trainer: "pl.Trainer", model: "pl.Light
) )
def __check_training_step_requires_dataloader_iter(model: "pl.LightningModule"): def __check_training_step_requires_dataloader_iter(model: "pl.LightningModule") -> None:
"""Check if the current `training_step` is requesting `dataloader_iter`.""" """Check if the current `training_step` is requesting `dataloader_iter`."""
training_step_fx = model.training_step training_step_fx = model.training_step
if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True): if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):

View File

@ -551,7 +551,7 @@ class Trainer(
) )
self._terminate_on_nan = terminate_on_nan self._terminate_on_nan = terminate_on_nan
self.gradient_clip_val = gradient_clip_val self.gradient_clip_val: Union[int, float] = gradient_clip_val
self.gradient_clip_algorithm = ( self.gradient_clip_algorithm = (
GradClipAlgorithmType(gradient_clip_algorithm.lower()) GradClipAlgorithmType(gradient_clip_algorithm.lower())
if gradient_clip_algorithm is not None if gradient_clip_algorithm is not None