Reset epoch progress with batch size scaler (#13846)
Co-authored-by: Christian Schell <christian.schell@uni-wuerzburg.de>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
commit 70deac2cd4
parent c418828d41
@@ -87,6 +87,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed wrong num padding for `RichProgressBar` ([#14296](https://github.com/Lightning-AI/lightning/pull/14296))
+- Reset epoch progress with batch size scaler ([#13846](https://github.com/Lightning-AI/lightning/pull/13846))
 - Fixed `LightningDataModule` hparams parsing ([#12806](https://github.com/PyTorchLightning/pytorch-lightning/pull/12806))
@@ -128,7 +128,10 @@ def _run_power_scaling(
     """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
     for _ in range(max_trials):
         garbage_collection_cuda()
-        trainer.fit_loop.global_step = 0  # reset after each try
+
+        # reset after each try
+        _reset_progress(trainer)
+
         try:
             # Try fit
             trainer.tuner._run(model)
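For reference, a minimal sketch of the power-scaling pattern this hunk touches, assuming hypothetical `try_fit` and `reset_progress` callables rather than the actual Trainer internals: progress is reset before every attempt, the batch size is doubled per trial, and the first out-of-memory failure ends the search.

```python
# Sketch only: `try_fit` and `reset_progress` are hypothetical stand-ins,
# not Lightning APIs.
def power_scale(try_fit, reset_progress, init_val: int = 2, max_trials: int = 25) -> int:
    batch_size = init_val
    for _ in range(max_trials):
        reset_progress()  # start each trial from clean progress counters
        try:
            try_fit(batch_size)
        except RuntimeError as err:
            if "out of memory" not in str(err).lower():
                raise  # only OOM-like failures end the search
            batch_size //= 2  # fall back to the last size that fit
            break
        batch_size *= 2  # trial succeeded: double and try again
    return batch_size
```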
@@ -166,7 +169,10 @@ def _run_binsearch_scaling(
     count = 0
     while True:
         garbage_collection_cuda()
-        trainer.fit_loop.global_step = 0  # reset after each try
+
+        # reset after each try
+        _reset_progress(trainer)
+
         try:
             # Try fit
             trainer.tuner._run(model)
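Similarly, a sketch of the binsearch phase with the same hypothetical helpers: once a size fails, the search narrows between the last working size and the failing size, again resetting progress before every attempt.

```python
# Sketch only: hypothetical helpers, not the Lightning implementation.
def binsearch_scale(try_fit, reset_progress, low: int, high: int, max_trials: int = 25) -> int:
    best = low
    count = 0
    while low <= high and count < max_trials:
        midval = (low + high) // 2
        reset_progress()  # reset counters before each attempt, as in the diff above
        try:
            try_fit(midval)
            best, low = midval, midval + 1  # size fits: search the upper half
        except RuntimeError as err:
            if "out of memory" not in str(err).lower():
                raise
            high = midval - 1  # OOM: search the lower half
        count += 1
    return best
```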
@@ -249,3 +255,12 @@ def _adjust_batch_size(
 def _is_valid_batch_size(batch_size: int, dataloader: DataLoader, trainer: "pl.Trainer"):
     module = trainer.lightning_module or trainer.datamodule
     return not has_len_all_ranks(dataloader, trainer.strategy, module) or batch_size <= len(dataloader)
+
+
+def _reset_progress(trainer: "pl.Trainer") -> None:
+    if trainer.lightning_module.automatic_optimization:
+        trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress.reset()
+    else:
+        trainer.fit_loop.epoch_loop.batch_loop.manual_loop.optim_step_progress.reset()
+
+    trainer.fit_loop.epoch_progress.reset()
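The trackers reset here are small counter dataclasses whose `reset()` cascades to nested counters, which is why one call per tracker suffices. A rough, illustrative sketch of that shape (not the actual `pytorch_lightning.trainer.progress` classes):

```python
from dataclasses import dataclass, field


@dataclass
class _Counter:  # illustrative only
    ready: int = 0
    completed: int = 0

    def reset(self) -> None:
        self.ready = 0
        self.completed = 0


@dataclass
class _EpochProgress:  # illustrative only
    current: _Counter = field(default_factory=_Counter)

    def reset(self) -> None:
        # resetting cascades to the nested counters
        self.current.reset()
```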
@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
 from copy import deepcopy
+from unittest.mock import patch

 import pytest
 import torch
@@ -308,10 +309,13 @@ def test_scale_batch_size_fails_with_unavailable_mode(tmpdir):
 def test_dataloader_reset_with_scale_batch_size(tmpdir, scale_method):
     """Test that train and val dataloaders are reset at every update in scale batch size."""
     model = BatchSizeModel(batch_size=16)
-    scale_batch_size_kwargs = {"max_trials": 5, "init_val": 4, "mode": scale_method}
+    max_trials = 5
+    scale_batch_size_kwargs = {"max_trials": max_trials, "steps_per_trial": 2, "init_val": 4, "mode": scale_method}

-    trainer = Trainer(max_epochs=2, auto_scale_batch_size=True)
-    new_batch_size = trainer.tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs)["scale_batch_size"]
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
+    with patch.object(model, "on_train_epoch_end") as advance_mocked:
+        new_batch_size = trainer.tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs)["scale_batch_size"]
+        assert advance_mocked.call_count == max_trials

     assert trainer.train_dataloader.loaders.batch_size == new_batch_size
     assert trainer.val_dataloaders[0].batch_size == new_batch_size
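For completeness, a usage sketch of the batch size finder exercised by this test, using a toy module written purely for illustration (not part of this commit) and assuming the 1.7-era `auto_scale_batch_size` / `trainer.tune` API shown above:

```python
# ToyModel is a hypothetical module for illustration; the tuner rewrites its
# `batch_size` attribute and returns the largest size that fits in memory.
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class ToyModel(pl.LightningModule):
    def __init__(self, batch_size: int = 4):
        super().__init__()
        self.batch_size = batch_size  # attribute the batch size finder tunes
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(512, 32)), batch_size=self.batch_size)


trainer = pl.Trainer(max_epochs=1, auto_scale_batch_size=True)
result = trainer.tune(
    ToyModel(), scale_batch_size_kwargs={"max_trials": 5, "steps_per_trial": 2, "init_val": 4, "mode": "power"}
)
print(result["scale_batch_size"])
```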