diff --git a/pyproject.toml b/pyproject.toml
index 497a32eab2..1bf681b5c3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -59,7 +59,6 @@ module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
     "pytorch_lightning.trainer.trainer",
     "pytorch_lightning.trainer.connectors.checkpoint_connector",
-    "pytorch_lightning.tuner.batch_size_scaling",
     "lightning_app.api.http_methods",
     "lightning_app.api.request_types",
     "lightning_app.cli.app-template.app",
diff --git a/src/pytorch_lightning/callbacks/batch_size_finder.py b/src/pytorch_lightning/callbacks/batch_size_finder.py
index d4a8d37da4..96b9f6eef8 100644
--- a/src/pytorch_lightning/callbacks/batch_size_finder.py
+++ b/src/pytorch_lightning/callbacks/batch_size_finder.py
@@ -31,6 +31,8 @@ from pytorch_lightning.utilities.rank_zero import rank_zero_warn
 class BatchSizeFinder(Callback):
     SUPPORTED_MODES = ("power", "binsearch")
 
+    optimal_batch_size: Optional[int]
+
     def __init__(
         self,
         mode: str = "power",
diff --git a/src/pytorch_lightning/tuner/batch_size_scaling.py b/src/pytorch_lightning/tuner/batch_size_scaling.py
index a85ef6a814..781c7ee119 100644
--- a/src/pytorch_lightning/tuner/batch_size_scaling.py
+++ b/src/pytorch_lightning/tuner/batch_size_scaling.py
@@ -35,10 +35,10 @@ def scale_batch_size(
     init_val: int = 2,
     max_trials: int = 25,
     batch_arg_name: str = "batch_size",
-):
+) -> Optional[int]:
     if trainer.fast_dev_run:
         rank_zero_warn("Skipping batch size scaler since `fast_dev_run` is enabled.")
-        return
+        return None
 
     # Save initial model, that is loaded after batch size is found
     ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt")
@@ -141,7 +141,12 @@ def __scale_batch_restore_params(trainer: "pl.Trainer", params: Dict[str, Any])
 
 
 def _run_power_scaling(
-    trainer: "pl.Trainer", pl_module: "pl.LightningModule", new_size: int, batch_arg_name: str, max_trials: int, params
+    trainer: "pl.Trainer",
+    pl_module: "pl.LightningModule",
+    new_size: int,
+    batch_arg_name: str,
+    max_trials: int,
+    params: Dict[str, Any],
 ) -> int:
     """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
     # this flag is used to determine whether the previously scaled batch size, right before OOM, was a success or not
@@ -179,7 +184,12 @@ def _run_binary_scaling(
-    trainer: "pl.Trainer", pl_module: "pl.LightningModule", new_size: int, batch_arg_name: str, max_trials: int, params
+    trainer: "pl.Trainer",
+    pl_module: "pl.LightningModule",
+    new_size: int,
+    batch_arg_name: str,
+    max_trials: int,
+    params: Dict[str, Any],
 ) -> int:
     """Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is
     encountered.
@@ -309,7 +319,7 @@ def _reset_dataloaders(trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
     reset_fn(pl_module)
 
 
-def _try_loop_run(trainer: "pl.Trainer", params) -> None:
+def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:
     if trainer.state.fn == "fit":
         loop = trainer.fit_loop
     else: