diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 32303d6bab..8dcd45f58b 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -87,6 +87,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed wrong num padding for `RichProgressBar` ([#14296](https://github.com/Lightning-AI/lightning/pull/14296))
 
+- Reset epoch progress with batch size scaler ([#13846](https://github.com/Lightning-AI/lightning/pull/13846))
+
+
 - Fixed `LightningDataModule` hparams parsing ([#12806](https://github.com/PyTorchLightning/pytorch-lightning/pull/12806))
diff --git a/src/pytorch_lightning/tuner/batch_size_scaling.py b/src/pytorch_lightning/tuner/batch_size_scaling.py
index 316fc5a219..a1f8a2de4b 100644
--- a/src/pytorch_lightning/tuner/batch_size_scaling.py
+++ b/src/pytorch_lightning/tuner/batch_size_scaling.py
@@ -128,7 +128,10 @@ def _run_power_scaling(
     """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
     for _ in range(max_trials):
         garbage_collection_cuda()
-        trainer.fit_loop.global_step = 0  # reset after each try
+
+        # reset after each try
+        _reset_progress(trainer)
+
         try:
             # Try fit
             trainer.tuner._run(model)
@@ -166,7 +169,10 @@ def _run_binsearch_scaling(
     count = 0
     while True:
         garbage_collection_cuda()
-        trainer.fit_loop.global_step = 0  # reset after each try
+
+        # reset after each try
+        _reset_progress(trainer)
+
         try:
             # Try fit
             trainer.tuner._run(model)
@@ -249,3 +255,12 @@ def _adjust_batch_size(
 def _is_valid_batch_size(batch_size: int, dataloader: DataLoader, trainer: "pl.Trainer"):
     module = trainer.lightning_module or trainer.datamodule
     return not has_len_all_ranks(dataloader, trainer.strategy, module) or batch_size <= len(dataloader)
+
+
+def _reset_progress(trainer: "pl.Trainer") -> None:
+    if trainer.lightning_module.automatic_optimization:
+        trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress.reset()
+    else:
+        trainer.fit_loop.epoch_loop.batch_loop.manual_loop.optim_step_progress.reset()
+
+    trainer.fit_loop.epoch_progress.reset()
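For context, here is a minimal, self-contained sketch of how the code path touched above is reached. The `ScalableModel` and `RandomDataset` names are illustrative and not part of this patch; the `Trainer(auto_scale_batch_size=...)` and `tune(..., scale_batch_size_kwargs=...)` calls mirror the 1.x API exercised by the test below. Each trial re-runs the fit loop, which is why `_run_power_scaling` / `_run_binsearch_scaling` now reset progress between trials instead of only zeroing `fit_loop.global_step`:

```python
# Hedged sketch (names are illustrative, not from the patch): a LightningModule
# whose `batch_size` hyperparameter the tuner can discover and rescale.
import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl


class RandomDataset(Dataset):
    def __init__(self, size: int = 64, length: int = 256):
        self.data = torch.randn(length, size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class ScalableModel(pl.LightningModule):
    def __init__(self, batch_size: int = 4):
        super().__init__()
        self.save_hyperparameters()  # exposes `self.hparams.batch_size` to the tuner
        self.layer = torch.nn.Linear(64, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        # Re-created each trial, so it picks up the rescaled batch size.
        return DataLoader(RandomDataset(), batch_size=self.hparams.batch_size)


if __name__ == "__main__":
    model = ScalableModel(batch_size=4)
    trainer = pl.Trainer(max_epochs=1, auto_scale_batch_size="binsearch")
    # `tune` runs the binsearch scaling loop patched above; each trial now calls
    # `_reset_progress(trainer)` before re-running fit.
    result = trainer.tune(model, scale_batch_size_kwargs={"max_trials": 5, "init_val": 4})
    print(result["scale_batch_size"])
```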
diff --git a/tests/tests_pytorch/tuner/test_scale_batch_size.py b/tests/tests_pytorch/tuner/test_scale_batch_size.py
index ce7c3613f5..e703b37491 100644
--- a/tests/tests_pytorch/tuner/test_scale_batch_size.py
+++ b/tests/tests_pytorch/tuner/test_scale_batch_size.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
 from copy import deepcopy
+from unittest.mock import patch
 
 import pytest
 import torch
@@ -308,10 +309,13 @@ def test_scale_batch_size_fails_with_unavailable_mode(tmpdir):
 def test_dataloader_reset_with_scale_batch_size(tmpdir, scale_method):
     """Test that train and val dataloaders are reset at every update in scale batch size."""
     model = BatchSizeModel(batch_size=16)
-    scale_batch_size_kwargs = {"max_trials": 5, "init_val": 4, "mode": scale_method}
+    max_trials = 5
+    scale_batch_size_kwargs = {"max_trials": max_trials, "steps_per_trial": 2, "init_val": 4, "mode": scale_method}
 
-    trainer = Trainer(max_epochs=2, auto_scale_batch_size=True)
-    new_batch_size = trainer.tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs)["scale_batch_size"]
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
+    with patch.object(model, "on_train_epoch_end") as advance_mocked:
+        new_batch_size = trainer.tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs)["scale_batch_size"]
+    assert advance_mocked.call_count == max_trials
 
     assert trainer.train_dataloader.loaders.batch_size == new_batch_size
     assert trainer.val_dataloaders[0].batch_size == new_batch_size
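The updated test relies on a standard `unittest.mock` idiom: with `max_epochs=1`, `on_train_epoch_end` should fire exactly once per trial, so the hook is patched and its call count compared to `max_trials`, which only holds if epoch progress is reset between trials. A tiny, self-contained sketch of that idiom (the `Loop` class here is illustrative only):

```python
# Standalone illustration of the mocking pattern used in the updated test:
# patch a method on an instance and count how often the wrapped code calls it.
from unittest.mock import patch


class Loop:
    def on_epoch_end(self):
        ...


loop = Loop()
trials = 5
with patch.object(loop, "on_epoch_end") as hook_mock:
    for _ in range(trials):  # stands in for the tuner re-running fit once per trial
        loop.on_epoch_end()
assert hook_mock.call_count == trials
```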