Fix typing in `pl.trainer.config_validator` (#10803)
This commit is contained in:
parent
2fc64e9656
commit
9a4b51d17f
|
@ -94,7 +94,6 @@ module = [
|
|||
"pytorch_lightning.profiler.pytorch",
|
||||
"pytorch_lightning.profiler.simple",
|
||||
"pytorch_lightning.trainer.callback_hook",
|
||||
"pytorch_lightning.trainer.configuration_validator",
|
||||
"pytorch_lightning.trainer.connectors.accelerator_connector",
|
||||
"pytorch_lightning.trainer.connectors.callback_connector",
|
||||
"pytorch_lightning.trainer.connectors.checkpoint_connector",
|
||||
|
|
|
@ -28,6 +28,8 @@ def verify_loop_configurations(trainer: "pl.Trainer", model: "pl.LightningModule
|
|||
model: The model to check the configuration.
|
||||
|
||||
"""
|
||||
if trainer.state.fn is None:
|
||||
raise ValueError("Unexpected: Trainer state fn must be set before validating loop configuration.")
|
||||
if trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
|
||||
__verify_train_val_loop_configuration(trainer, model)
|
||||
__verify_manual_optimization_support(trainer, model)
|
||||
|
@ -221,7 +223,7 @@ def __verify_manual_optimization_support(trainer: "pl.Trainer", model: "pl.Light
|
|||
)
|
||||
|
||||
|
||||
def __check_training_step_requires_dataloader_iter(model: "pl.LightningModule"):
|
||||
def __check_training_step_requires_dataloader_iter(model: "pl.LightningModule") -> None:
|
||||
"""Check if the current `training_step` is requesting `dataloader_iter`."""
|
||||
training_step_fx = model.training_step
|
||||
if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
|
||||
|
|
|
@ -551,7 +551,7 @@ class Trainer(
|
|||
)
|
||||
|
||||
self._terminate_on_nan = terminate_on_nan
|
||||
self.gradient_clip_val = gradient_clip_val
|
||||
self.gradient_clip_val: Union[int, float] = gradient_clip_val
|
||||
self.gradient_clip_algorithm = (
|
||||
GradClipAlgorithmType(gradient_clip_algorithm.lower())
|
||||
if gradient_clip_algorithm is not None
|
||||
|
|
Loading…
Reference in New Issue