Update Trainer's ckpt_path type for pathlib Path (#19362)

This commit is contained in:
awaelchli 2024-01-30 00:42:18 +01:00 committed by GitHub
parent b0e1ee2469
commit bcc8de8dec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 8 additions and 8 deletions

View File

@ -506,7 +506,7 @@ class Trainer:
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None, val_dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional[LightningDataModule] = None, datamodule: Optional[LightningDataModule] = None,
ckpt_path: Optional[str] = None, ckpt_path: Optional[_PATH] = None,
) -> None: ) -> None:
r"""Runs the full optimization routine. r"""Runs the full optimization routine.
@ -550,7 +550,7 @@ class Trainer:
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None, val_dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional[LightningDataModule] = None, datamodule: Optional[LightningDataModule] = None,
ckpt_path: Optional[str] = None, ckpt_path: Optional[_PATH] = None,
) -> None: ) -> None:
log.debug(f"{self.__class__.__name__}: trainer fit stage") log.debug(f"{self.__class__.__name__}: trainer fit stage")
@ -586,7 +586,7 @@ class Trainer:
self, self,
model: Optional["pl.LightningModule"] = None, model: Optional["pl.LightningModule"] = None,
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
ckpt_path: Optional[str] = None, ckpt_path: Optional[_PATH] = None,
verbose: bool = True, verbose: bool = True,
datamodule: Optional[LightningDataModule] = None, datamodule: Optional[LightningDataModule] = None,
) -> _EVALUATE_OUTPUT: ) -> _EVALUATE_OUTPUT:
@ -649,7 +649,7 @@ class Trainer:
self, self,
model: Optional["pl.LightningModule"] = None, model: Optional["pl.LightningModule"] = None,
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
ckpt_path: Optional[str] = None, ckpt_path: Optional[_PATH] = None,
verbose: bool = True, verbose: bool = True,
datamodule: Optional[LightningDataModule] = None, datamodule: Optional[LightningDataModule] = None,
) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]: ) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]:
@ -694,7 +694,7 @@ class Trainer:
self, self,
model: Optional["pl.LightningModule"] = None, model: Optional["pl.LightningModule"] = None,
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
ckpt_path: Optional[str] = None, ckpt_path: Optional[_PATH] = None,
verbose: bool = True, verbose: bool = True,
datamodule: Optional[LightningDataModule] = None, datamodule: Optional[LightningDataModule] = None,
) -> _EVALUATE_OUTPUT: ) -> _EVALUATE_OUTPUT:
@ -758,7 +758,7 @@ class Trainer:
self, self,
model: Optional["pl.LightningModule"] = None, model: Optional["pl.LightningModule"] = None,
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
ckpt_path: Optional[str] = None, ckpt_path: Optional[_PATH] = None,
verbose: bool = True, verbose: bool = True,
datamodule: Optional[LightningDataModule] = None, datamodule: Optional[LightningDataModule] = None,
) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]: ) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]:
@ -805,7 +805,7 @@ class Trainer:
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
datamodule: Optional[LightningDataModule] = None, datamodule: Optional[LightningDataModule] = None,
return_predictions: Optional[bool] = None, return_predictions: Optional[bool] = None,
ckpt_path: Optional[str] = None, ckpt_path: Optional[_PATH] = None,
) -> Optional[_PREDICT_OUTPUT]: ) -> Optional[_PREDICT_OUTPUT]:
r"""Run inference on your data. This will call the model forward function to compute predictions. Useful to r"""Run inference on your data. This will call the model forward function to compute predictions. Useful to
perform distributed and batched predictions. Logging is disabled in the predict hooks. perform distributed and batched predictions. Logging is disabled in the predict hooks.
@ -870,7 +870,7 @@ class Trainer:
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
datamodule: Optional[LightningDataModule] = None, datamodule: Optional[LightningDataModule] = None,
return_predictions: Optional[bool] = None, return_predictions: Optional[bool] = None,
ckpt_path: Optional[str] = None, ckpt_path: Optional[_PATH] = None,
) -> Optional[_PREDICT_OUTPUT]: ) -> Optional[_PREDICT_OUTPUT]:
# -------------------- # --------------------
# SETUP HOOK # SETUP HOOK