Reset epoch progress with batch size scaler (#13846)

Co-authored-by: Christian Schell <christian.schell@uni-wuerzburg.de>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Authored by Christian Schell on 2022-08-26 10:42:00 +02:00, committed via GitHub
parent c418828d41
commit 70deac2cd4
3 changed files with 27 additions and 5 deletions
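
For context, the batch size scaler touched by this commit is driven through `Trainer.tune`. A minimal usage sketch follows; only the flags and kwargs that appear in the diffs below are assumed, and `MyBatchSizeModel` is a placeholder for any LightningModule whose dataloaders read `self.batch_size`:

    from pytorch_lightning import Trainer

    # Placeholder module: anything that builds its dataloaders from `self.batch_size`
    # (or another attribute named via `batch_arg_name`) works with the scaler.
    model = MyBatchSizeModel(batch_size=16)

    trainer = Trainer(max_epochs=1, auto_scale_batch_size=True)
    result = trainer.tune(
        model,
        scale_batch_size_kwargs={"max_trials": 5, "steps_per_trial": 2, "init_val": 4, "mode": "power"},
    )
    new_batch_size = result["scale_batch_size"]  # largest batch size the search settled on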

CHANGELOG.md

@@ -87,6 +87,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed wrong num padding for `RichProgressBar` ([#14296](https://github.com/Lightning-AI/lightning/pull/14296))
+
+- Reset epoch progress with batch size scaler ([#13846](https://github.com/Lightning-AI/lightning/pull/13846))
+
 - Fixed `LightningDataModule` hparams parsing ([#12806](https://github.com/PyTorchLightning/pytorch-lightning/pull/12806))

src/pytorch_lightning/tuner/batch_size_scaling.py

@@ -128,7 +128,10 @@ def _run_power_scaling(
     """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
     for _ in range(max_trials):
         garbage_collection_cuda()
-        trainer.fit_loop.global_step = 0  # reset after each try
+
+        # reset after each try
+        _reset_progress(trainer)
+
         try:
             # Try fit
             trainer.tuner._run(model)
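
The docstring above summarises the power-scaling strategy; conceptually it reduces to the loop sketched below. This is illustrative only: `run_trial` is a hypothetical stand-in for the internal reset-and-fit cycle (`_reset_progress` plus `trainer.tuner._run`), and the real code adjusts sizes through `_adjust_batch_size` and catches CUDA OOM errors specifically.

    def power_scale(initial_size: int, max_trials: int) -> int:
        size = initial_size
        for _ in range(max_trials):
            try:
                run_trial(size)           # short fit with the candidate batch size
            except RuntimeError:          # stand-in for the OOM check
                size = max(1, size // 2)  # back off once and stop searching
                break
            size *= 2                     # no OOM: double and try again
        return size
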
@@ -166,7 +169,10 @@ def _run_binsearch_scaling(
     count = 0
     while True:
         garbage_collection_cuda()
-        trainer.fit_loop.global_step = 0  # reset after each try
+
+        # reset after each try
+        _reset_progress(trainer)
+
         try:
             # Try fit
             trainer.tuner._run(model)
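
The binsearch mode, by contrast, keeps growing the size only until the first failure and then bisects between the last size that fit and the first size that did not. Again a rough sketch with the same hypothetical `run_trial` helper, not the module's actual code:

    def binsearch_scale(initial_size: int, max_trials: int) -> int:
        low, high = initial_size, None  # last good size, first failing size
        size = initial_size
        succeeded = 0
        while True:
            try:
                run_trial(size)
                succeeded += 1
                low = size
                if succeeded >= max_trials:
                    break
                size = size * 2 if high is None else (low + high) // 2
            except RuntimeError:        # stand-in for the OOM check
                high = size
                size = (low + high) // 2
            if high is not None and high - low <= 1:
                break
        return low
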
@@ -249,3 +255,12 @@ def _adjust_batch_size(
 def _is_valid_batch_size(batch_size: int, dataloader: DataLoader, trainer: "pl.Trainer"):
     module = trainer.lightning_module or trainer.datamodule
     return not has_len_all_ranks(dataloader, trainer.strategy, module) or batch_size <= len(dataloader)
+
+
+def _reset_progress(trainer: "pl.Trainer") -> None:
+    if trainer.lightning_module.automatic_optimization:
+        trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress.reset()
+    else:
+        trainer.fit_loop.epoch_loop.batch_loop.manual_loop.optim_step_progress.reset()
+
+    trainer.fit_loop.epoch_progress.reset()
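
Resetting the optimizer and epoch progress trackers here, rather than only zeroing `trainer.fit_loop.global_step` as before, presumably gives every scaling trial a clean fit-loop state, so progress accumulated in one trial cannot make a later trial stop early. The updated test below checks exactly that: with `on_train_epoch_end` mocked, each of the `max_trials` trials is expected to run its own epoch.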

tests/tests_pytorch/tuner/test_scale_batch_size.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
 from copy import deepcopy
+from unittest.mock import patch

 import pytest
 import torch
@@ -308,10 +309,13 @@ def test_dataloader_reset_with_scale_batch_size(tmpdir, scale_method):
 def test_dataloader_reset_with_scale_batch_size(tmpdir, scale_method):
     """Test that train and val dataloaders are reset at every update in scale batch size."""
     model = BatchSizeModel(batch_size=16)
-    scale_batch_size_kwargs = {"max_trials": 5, "init_val": 4, "mode": scale_method}
+    max_trials = 5
+    scale_batch_size_kwargs = {"max_trials": max_trials, "steps_per_trial": 2, "init_val": 4, "mode": scale_method}
-    trainer = Trainer(max_epochs=2, auto_scale_batch_size=True)
-    new_batch_size = trainer.tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs)["scale_batch_size"]
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
+    with patch.object(model, "on_train_epoch_end") as advance_mocked:
+        new_batch_size = trainer.tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs)["scale_batch_size"]
+    assert advance_mocked.call_count == max_trials

     assert trainer.train_dataloader.loaders.batch_size == new_batch_size
     assert trainer.val_dataloaders[0].batch_size == new_batch_size