[Test] Add extra test for val_check_interval in distributed scenario (#7863)
* add extra test
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* add computation
* Update docs/source/common/trainer.rst (Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>)
* Update docs/source/common/trainer.rst (Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>)
* Update tests/trainer/test_dataloaders.py (Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>)
* use tmpdir
* update on comments
* update
* Update tests/callbacks/test_progress_bar.py (Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
parent 6388c29e87
commit ea71cf4a5f
@@ -1527,6 +1527,24 @@ Can specify as float or int.

    trainer = Trainer(val_check_interval=1000)

.. code-block::

    # Here is the computation to estimate the total number of batches seen within an epoch.

    # Find the total number of train batches
    total_train_batches = total_train_samples // (train_batch_size * world_size)

    # Compute how many times we will call validation during the training loop
    val_check_batch = max(1, int(total_train_batches * val_check_interval))
    val_checks_per_epoch = total_train_batches / val_check_batch

    # Find the total number of validation batches
    total_val_batches = total_val_samples // (val_batch_size * world_size)

    # Total number of batches run
    total_fit_batches = total_train_batches + total_val_batches

weights_save_path
^^^^^^^^^^^^^^^^^
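As a worked illustration of the formulas above (the numbers here are made up for the example and do not come from the diff), consider a run on 4 GPUs under DDP:

    # Hypothetical example values, chosen only to illustrate the arithmetic.
    total_train_samples = 10_000
    train_batch_size = 32
    total_val_samples = 2_000
    val_batch_size = 32
    world_size = 4            # e.g. 4 GPUs with DDP
    val_check_interval = 0.25

    total_train_batches = total_train_samples // (train_batch_size * world_size)  # 10_000 // 128 = 78
    val_check_batch = max(1, int(total_train_batches * val_check_interval))       # int(78 * 0.25) = 19
    val_checks_per_epoch = total_train_batches / val_check_batch                  # 78 / 19 ≈ 4.1
    total_val_batches = total_val_samples // (val_batch_size * world_size)        # 2_000 // 128 = 15
    total_fit_batches = total_train_batches + total_val_batches                   # 78 + 15 = 93

So with these hypothetical settings, validation would be triggered every 19 training batches, roughly 4 times per epoch.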
@@ -20,12 +20,14 @@ from unittest.mock import ANY, call, Mock

import pytest
import torch
from torch.utils.data.dataloader import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase
from pytorch_lightning.callbacks.progress import tqdm
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
from tests.helpers.boring_model import BoringModel, RandomDataset
from tests.helpers.runif import RunIf


@pytest.mark.parametrize(
@@ -533,3 +535,47 @@ def test_progress_bar_can_be_pickled():
    pickle.dumps(bar)
    trainer.predict(model)
    pickle.dumps(bar)


@RunIf(min_gpus=2, special=True)
@pytest.mark.parametrize([
    "total_train_samples",
    "train_batch_size",
    "total_val_samples",
    "val_batch_size",
    "val_check_interval",
], [
    (8, 4, 2, 1, 0.2),
    (8, 4, 2, 1, 0.5),
])
def test_progress_bar_max_val_check_interval(
    total_train_samples, train_batch_size, total_val_samples, val_batch_size, val_check_interval, tmpdir
):
    world_size = 2

    train_data = DataLoader(RandomDataset(32, total_train_samples), batch_size=train_batch_size)
    val_data = DataLoader(RandomDataset(32, total_val_samples), batch_size=val_batch_size)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=0,
        max_epochs=1,
        weights_summary=None,
        val_check_interval=val_check_interval,
        gpus=world_size,
        accelerator="ddp",
    )
    trainer.fit(model, train_dataloader=train_data, val_dataloaders=val_data)

    total_train_batches = total_train_samples // (train_batch_size * world_size)
    val_check_batch = max(1, int(total_train_batches * val_check_interval))
    assert trainer.val_check_batch == val_check_batch
    val_checks_per_epoch = total_train_batches / val_check_batch
    total_val_batches = total_val_samples // (val_batch_size * world_size)
    assert trainer.progress_bar_callback.total_train_batches == total_train_batches
    assert trainer.progress_bar_callback.total_val_batches == total_val_batches
    total_val_batches = total_val_batches * val_checks_per_epoch
    if trainer.is_global_zero:
        assert trainer.progress_bar_callback.main_progress_bar.total == total_train_batches + total_val_batches
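To see what the new test's assertions reduce to numerically, here is a small standalone sketch (not part of the PR) that evaluates the same expressions for both parametrized cases. Because each process sees only one training batch (8 // (4 * 2) = 1), the max(1, ...) clamp makes val_check_batch equal to 1 for both 0.2 and 0.5, so both cases expect a main progress bar total of 2.

    # Standalone sketch (not part of the PR): evaluate the expected values for
    # both parametrizations of the test, assuming world_size = 2 as in the test.
    def expected_totals(total_train_samples, train_batch_size, total_val_samples,
                        val_batch_size, val_check_interval, world_size=2):
        total_train_batches = total_train_samples // (train_batch_size * world_size)
        val_check_batch = max(1, int(total_train_batches * val_check_interval))
        val_checks_per_epoch = total_train_batches / val_check_batch
        total_val_batches = total_val_samples // (val_batch_size * world_size)
        main_bar_total = total_train_batches + total_val_batches * val_checks_per_epoch
        return total_train_batches, val_check_batch, main_bar_total

    for interval in (0.2, 0.5):
        train_b, check_b, bar_total = expected_totals(8, 4, 2, 1, interval)
        # One train batch per process, validation runs once, so the bar shows 1 + 1 = 2.
        assert (train_b, check_b, bar_total) == (1, 1, 2)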