Fix typing in `pl.trainer.config_validator` (#10803)
This commit is contained in:
parent
2fc64e9656
commit
9a4b51d17f
|
@ -94,7 +94,6 @@ module = [
|
||||||
"pytorch_lightning.profiler.pytorch",
|
"pytorch_lightning.profiler.pytorch",
|
||||||
"pytorch_lightning.profiler.simple",
|
"pytorch_lightning.profiler.simple",
|
||||||
"pytorch_lightning.trainer.callback_hook",
|
"pytorch_lightning.trainer.callback_hook",
|
||||||
"pytorch_lightning.trainer.configuration_validator",
|
|
||||||
"pytorch_lightning.trainer.connectors.accelerator_connector",
|
"pytorch_lightning.trainer.connectors.accelerator_connector",
|
||||||
"pytorch_lightning.trainer.connectors.callback_connector",
|
"pytorch_lightning.trainer.connectors.callback_connector",
|
||||||
"pytorch_lightning.trainer.connectors.checkpoint_connector",
|
"pytorch_lightning.trainer.connectors.checkpoint_connector",
|
||||||
|
|
|
@ -28,6 +28,8 @@ def verify_loop_configurations(trainer: "pl.Trainer", model: "pl.LightningModule
|
||||||
model: The model to check the configuration.
|
model: The model to check the configuration.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if trainer.state.fn is None:
|
||||||
|
raise ValueError("Unexpected: Trainer state fn must be set before validating loop configuration.")
|
||||||
if trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
|
if trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
|
||||||
__verify_train_val_loop_configuration(trainer, model)
|
__verify_train_val_loop_configuration(trainer, model)
|
||||||
__verify_manual_optimization_support(trainer, model)
|
__verify_manual_optimization_support(trainer, model)
|
||||||
|
@ -221,7 +223,7 @@ def __verify_manual_optimization_support(trainer: "pl.Trainer", model: "pl.Light
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def __check_training_step_requires_dataloader_iter(model: "pl.LightningModule"):
|
def __check_training_step_requires_dataloader_iter(model: "pl.LightningModule") -> None:
|
||||||
"""Check if the current `training_step` is requesting `dataloader_iter`."""
|
"""Check if the current `training_step` is requesting `dataloader_iter`."""
|
||||||
training_step_fx = model.training_step
|
training_step_fx = model.training_step
|
||||||
if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
|
if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
|
||||||
|
|
|
@ -551,7 +551,7 @@ class Trainer(
|
||||||
)
|
)
|
||||||
|
|
||||||
self._terminate_on_nan = terminate_on_nan
|
self._terminate_on_nan = terminate_on_nan
|
||||||
self.gradient_clip_val = gradient_clip_val
|
self.gradient_clip_val: Union[int, float] = gradient_clip_val
|
||||||
self.gradient_clip_algorithm = (
|
self.gradient_clip_algorithm = (
|
||||||
GradClipAlgorithmType(gradient_clip_algorithm.lower())
|
GradClipAlgorithmType(gradient_clip_algorithm.lower())
|
||||||
if gradient_clip_algorithm is not None
|
if gradient_clip_algorithm is not None
|
||||||
|
|
Loading…
Reference in New Issue