Fix type hints of tuner/batch_size_scaling.py (#13518)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: otaj <ota@lightning.ai>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Masahiro Wada, 2022-09-29 21:00:42 +09:00 (committed by GitHub)
commit d377d0efde (parent 136d57312d)
3 changed files with 17 additions and 6 deletions

pyproject.toml

@@ -59,7 +59,6 @@ module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
     "pytorch_lightning.trainer.trainer",
     "pytorch_lightning.trainer.connectors.checkpoint_connector",
-    "pytorch_lightning.tuner.batch_size_scaling",
     "lightning_app.api.http_methods",
     "lightning_app.api.request_types",
     "lightning_app.cli.app-template.app",

src/pytorch_lightning/callbacks/batch_size_finder.py

@@ -31,6 +31,8 @@ from pytorch_lightning.utilities.rank_zero import rank_zero_warn
 class BatchSizeFinder(Callback):
     SUPPORTED_MODES = ("power", "binsearch")
 
+    optimal_batch_size: Optional[int]
+
     def __init__(
         self,
         mode: str = "power",

src/pytorch_lightning/tuner/batch_size_scaling.py

@@ -35,10 +35,10 @@ def scale_batch_size(
     init_val: int = 2,
     max_trials: int = 25,
     batch_arg_name: str = "batch_size",
-):
+) -> Optional[int]:
     if trainer.fast_dev_run:
         rank_zero_warn("Skipping batch size scaler since `fast_dev_run` is enabled.")
-        return
+        return None
 
     # Save initial model, that is loaded after batch size is found
     ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt")
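The `Optional[int]` return type, together with the explicit `return None`, tells callers and mypy alike that the scan may be skipped. A hedged caller-side sketch, assuming `scale_batch_size` takes the trainer and module first as its helpers below do; `trainer` and `model` stand for already-built objects:

from typing import Optional

from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size

# None means scaling was skipped (fast_dev_run), so narrow before using it.
new_size: Optional[int] = scale_batch_size(trainer, model)
if new_size is not None:
    model.batch_size = new_size  # batch_arg_name defaults to "batch_size"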
@@ -141,7 +141,12 @@ def __scale_batch_restore_params(trainer: "pl.Trainer", params: Dict[str, Any])
 
 
 def _run_power_scaling(
-    trainer: "pl.Trainer", pl_module: "pl.LightningModule", new_size: int, batch_arg_name: str, max_trials: int, params
+    trainer: "pl.Trainer",
+    pl_module: "pl.LightningModule",
+    new_size: int,
+    batch_arg_name: str,
+    max_trials: int,
+    params: Dict[str, Any],
 ) -> int:
     """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
     # this flag is used to determine whether the previously scaled batch size, right before OOM, was a success or not
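The docstring above states the whole "power" algorithm: double until OOM, then fall back to the last size that fit. A self-contained toy sketch of that loop, not this file's code; `run_trial` and `MEMORY_LIMIT` simulate a short fitting run and GPU capacity:

MEMORY_LIMIT = 512  # simulated capacity standing in for real GPU memory

def run_trial(batch_size: int) -> None:
    # Stand-in for running a few training steps at the candidate size.
    if batch_size > MEMORY_LIMIT:
        raise RuntimeError("CUDA out of memory (simulated)")

def power_scale(init_val: int = 2, max_trials: int = 25) -> int:
    new_size = init_val
    for _ in range(max_trials):
        try:
            run_trial(new_size)
        except RuntimeError:
            new_size //= 2  # back off to the last size that fit
            break
        new_size *= 2  # success, so try double next
    return new_size

print(power_scale())  # -> 512 under the simulated limit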
@@ -179,7 +184,12 @@ def _run_power_scaling(
 
 
 def _run_binary_scaling(
-    trainer: "pl.Trainer", pl_module: "pl.LightningModule", new_size: int, batch_arg_name: str, max_trials: int, params
+    trainer: "pl.Trainer",
+    pl_module: "pl.LightningModule",
+    new_size: int,
+    batch_arg_name: str,
+    max_trials: int,
+    params: Dict[str, Any],
 ) -> int:
     """Batch scaling mode where the size is initially doubled at each iteration until an OOM error is
     encountered.
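"binsearch" begins with the same doubling phase but keeps a bracket, the largest size that fit (`low`) and the smallest that failed (`high`), and bisects between them once an OOM occurs. A toy sketch reusing `run_trial` from the previous example:

def binsearch_scale(init_val: int = 2, max_trials: int = 25) -> int:
    low, high = 1, None  # last success / first failure
    new_size = init_val
    for _ in range(max_trials):
        try:
            run_trial(new_size)
            low = new_size
            if high is not None and high - low <= 1:
                break  # bracket closed; low is the largest fitting size
            new_size = new_size * 2 if high is None else (low + high) // 2
        except RuntimeError:
            high = new_size
            new_size = (low + high) // 2
    return low

With a limit that is not a power of two (say 700), power scaling would stop at 512 while the binary search converges on 700, which is why the extra phase exists.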
@@ -309,7 +319,7 @@ def _reset_dataloaders(trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
     reset_fn(pl_module)
 
 
-def _try_loop_run(trainer: "pl.Trainer", params) -> None:
+def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:
     if trainer.state.fn == "fit":
         loop = trainer.fit_loop
     else:
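One note on the weakest hint in the commit: `params` stays `Dict[str, Any]` because, as the `__scale_batch_restore_params` context above suggests, it is a heterogeneous snapshot of trainer settings dumped before tuning and restored afterwards. A generic sketch of that pattern, with hypothetical names:

from typing import Any, Dict

def dump_params(obj: Any, names: tuple) -> Dict[str, Any]:
    # Attributes of mixed types, hence Dict[str, Any] rather than a TypedDict.
    return {name: getattr(obj, name) for name in names}

def restore_params(obj: Any, params: Dict[str, Any]) -> None:
    for name, value in params.items():
        setattr(obj, name, value)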