Bugfix/18394 batch size finder max val batches (#18854)
Co-authored-by: Oleksandra Sokol <o.sokol@samsung.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
This commit is contained in:
parent
874825857f
commit
e50b68aae3
|
@ -33,6 +33,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed an issue when replacing an existing `last.ckpt` file with a symlink ([#18793](https://github.com/Lightning-AI/lightning/pull/18793))
|
||||
|
||||
|
||||
- Fixed an issue when `BatchSizeFinder` `steps_per_trial` parameter ends up defining how many validation batches to run during the entire training ([#18394](https://github.com/Lightning-AI/lightning/issues/18394))
|
||||
|
||||
|
||||
|
||||
## [2.1.0] - 2023-10-11
|
||||
|
||||
|
|
|
@ -323,6 +323,9 @@ def _reset_dataloaders(trainer: "pl.Trainer") -> None:
|
|||
assert loop is not None
|
||||
loop._combined_loader = None # force a reload
|
||||
loop.setup_data()
|
||||
if isinstance(loop, pl.loops._FitLoop):
|
||||
loop.epoch_loop.val_loop._combined_loader = None
|
||||
loop.epoch_loop.val_loop.setup_data()
|
||||
|
||||
|
||||
def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:
|
||||
|
|
|
@ -317,7 +317,7 @@ def test_dataloader_reset_with_scale_batch_size(tmp_path, caplog, scale_method,
|
|||
assert caplog.text.count("greater or equal than the length") == int(new_batch_size == dataset_len)
|
||||
|
||||
assert trainer.train_dataloader.batch_size == new_batch_size
|
||||
assert trainer.val_dataloaders.batch_size == init_batch_size
|
||||
assert trainer.val_dataloaders.batch_size == new_batch_size
|
||||
|
||||
|
||||
@pytest.mark.parametrize("trainer_fn", ["validate", "test", "predict"])
|
||||
|
@ -469,3 +469,20 @@ def test_dataloader_batch_size_updated_on_failure(_, tmpdir, scale_method, expec
|
|||
assert new_batch_size == model.batch_size
|
||||
assert new_batch_size == expected_batch_size
|
||||
assert trainer.train_dataloader.batch_size == expected_batch_size
|
||||
|
||||
|
||||
def test_batch_size_finder_callback_val_batches(tmpdir):
    """Ensure ``BatchSizeFinder`` leaves the number of validation batches untouched after training.

    Regression test for #18394: the finder's ``steps_per_trial`` must only cap
    validation batches *during* the scaling trials, not for the rest of training.
    """
    trial_steps = 2
    model = BatchSizeModel(batch_size=16)
    finder = BatchSizeFinder(steps_per_trial=trial_steps, max_trials=1)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        num_sanity_val_steps=0,
        enable_model_summary=False,
        callbacks=[finder],
    )
    trainer.fit(model)

    # After fit, the full validation set must be evaluated again — not the
    # trial-limited count the finder used internally.
    assert trainer.num_val_batches[0] == len(trainer.val_dataloaders)
    assert trainer.num_val_batches[0] != trial_steps
|
||||
|
|
Loading…
Reference in New Issue