Fix type hints of tuner/batch_size_scaling.py (#13518)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> Co-authored-by: otaj <ota@lightning.ai> Co-authored-by: Jirka <jirka.borovec@seznam.cz>
This commit is contained in:
parent
136d57312d
commit
d377d0efde
|
@ -59,7 +59,6 @@ module = [
|
|||
"pytorch_lightning.callbacks.progress.rich_progress",
|
||||
"pytorch_lightning.trainer.trainer",
|
||||
"pytorch_lightning.trainer.connectors.checkpoint_connector",
|
||||
"pytorch_lightning.tuner.batch_size_scaling",
|
||||
"lightning_app.api.http_methods",
|
||||
"lightning_app.api.request_types",
|
||||
"lightning_app.cli.app-template.app",
|
||||
|
|
|
@ -31,6 +31,8 @@ from pytorch_lightning.utilities.rank_zero import rank_zero_warn
|
|||
class BatchSizeFinder(Callback):
|
||||
SUPPORTED_MODES = ("power", "binsearch")
|
||||
|
||||
optimal_batch_size: Optional[int]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mode: str = "power",
|
||||
|
|
|
@ -35,10 +35,10 @@ def scale_batch_size(
|
|||
init_val: int = 2,
|
||||
max_trials: int = 25,
|
||||
batch_arg_name: str = "batch_size",
|
||||
):
|
||||
) -> Optional[int]:
|
||||
if trainer.fast_dev_run:
|
||||
rank_zero_warn("Skipping batch size scaler since `fast_dev_run` is enabled.")
|
||||
return
|
||||
return None
|
||||
|
||||
# Save initial model, that is loaded after batch size is found
|
||||
ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt")
|
||||
|
@ -141,7 +141,12 @@ def __scale_batch_restore_params(trainer: "pl.Trainer", params: Dict[str, Any])
|
|||
|
||||
|
||||
def _run_power_scaling(
|
||||
trainer: "pl.Trainer", pl_module: "pl.LightningModule", new_size: int, batch_arg_name: str, max_trials: int, params
|
||||
trainer: "pl.Trainer",
|
||||
pl_module: "pl.LightningModule",
|
||||
new_size: int,
|
||||
batch_arg_name: str,
|
||||
max_trials: int,
|
||||
params: Dict[str, Any],
|
||||
) -> int:
|
||||
"""Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
|
||||
# this flag is used to determine whether the previously scaled batch size, right before OOM, was a success or not
|
||||
|
@ -179,7 +184,12 @@ def _run_power_scaling(
|
|||
|
||||
|
||||
def _run_binary_scaling(
|
||||
trainer: "pl.Trainer", pl_module: "pl.LightningModule", new_size: int, batch_arg_name: str, max_trials: int, params
|
||||
trainer: "pl.Trainer",
|
||||
pl_module: "pl.LightningModule",
|
||||
new_size: int,
|
||||
batch_arg_name: str,
|
||||
max_trials: int,
|
||||
params: Dict[str, Any],
|
||||
) -> int:
|
||||
"""Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is
|
||||
encountered.
|
||||
|
@ -309,7 +319,7 @@ def _reset_dataloaders(trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
|
|||
reset_fn(pl_module)
|
||||
|
||||
|
||||
def _try_loop_run(trainer: "pl.Trainer", params) -> None:
|
||||
def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:
|
||||
if trainer.state.fn == "fit":
|
||||
loop = trainer.fit_loop
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue