1531 lines
59 KiB
Python
1531 lines
59 KiB
Python
# Copyright The PyTorch Lightning team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import os
|
|
from unittest.mock import call, Mock, patch
|
|
|
|
import numpy
|
|
import pytest
|
|
import torch
|
|
from torch.utils.data import RandomSampler
|
|
from torch.utils.data.dataloader import DataLoader
|
|
from torch.utils.data.dataset import Dataset, IterableDataset, Subset
|
|
from torch.utils.data.distributed import DistributedSampler
|
|
from torch.utils.data.sampler import SequentialSampler
|
|
|
|
import tests.helpers.pipelines as tpipes
|
|
from pytorch_lightning import Callback, seed_everything, Trainer
|
|
from pytorch_lightning.callbacks import ModelCheckpoint
|
|
from pytorch_lightning.utilities.data import _auto_add_worker_init_fn, has_iterable_dataset, has_len_all_ranks
|
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
|
from tests.base import EvalModelTemplate
|
|
from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset, RandomIterableDatasetWithLen
|
|
from tests.helpers.runif import RunIf
|
|
|
|
|
|
def test_fit_train_loader_only(tmpdir):
    """Fit with a train dataloader passed to ``fit`` after all loader and eval hooks were set to ``None``."""
    model = EvalModelTemplate()
    # grab the loader before the hook is removed from the model
    loader = model.train_dataloader()

    disabled_hooks = (
        "train_dataloader",
        "val_dataloader",
        "test_dataloader",
        "validation_step",
        "validation_epoch_end",
        "test_step",
        "test_epoch_end",
    )
    for hook_name in disabled_hooks:
        setattr(model, hook_name, None)

    trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    trainer.fit(model, train_dataloaders=loader)
|
|
|
|
|
|
def test_fit_val_loader_only(tmpdir):
    """Fit with train and val dataloaders passed to ``fit`` while the loader and test hooks are ``None``."""
    model = EvalModelTemplate()
    # build both loaders before the hooks are removed from the model
    train_loader = model.train_dataloader()
    val_loader = model.val_dataloader()

    # the validation hooks are intentionally left in place
    for hook_name in ("train_dataloader", "val_dataloader", "test_dataloader", "test_step", "test_epoch_end"):
        setattr(model, hook_name, None)

    trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
|
|
|
|
|
|
@pytest.mark.parametrize("dataloader_options", [{"val_check_interval": 10000}])
def test_dataloader_config_errors_runtime(tmpdir, dataloader_options):
    """Check that ``fit`` raises a ``ValueError`` for the parametrized Trainer options."""
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, **dataloader_options)
    model = EvalModelTemplate()
    # the Trainer is constructed without error; the failure is expected from ``fit``
    with pytest.raises(ValueError):
        trainer.fit(model)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "dataloader_options",
    [
        {option: invalid_value}
        for option in (
            "limit_train_batches",
            "limit_val_batches",
            "limit_test_batches",
            "val_check_interval",
            "overfit_batches",
        )
        # one value below and one value above the range checked by this test
        for invalid_value in (-0.1, 1.2)
    ],
)
def test_dataloader_config_errors_init(tmpdir, dataloader_options):
    """Check that constructing a Trainer with each invalid option raises a ``MisconfigurationException``."""
    with pytest.raises(MisconfigurationException, match="passed invalid value"):
        Trainer(default_root_dir=tmpdir, max_epochs=1, **dataloader_options)
|
|
|
|
|
|
def test_multiple_val_dataloader(tmpdir):
    """Fit with the multi-dataloader validation hooks and run the prediction check on both val loaders."""
    model = EvalModelTemplate()
    # swap in the multi-dataloader variants of the validation hooks
    model.val_dataloader = model.val_dataloader__multiple
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=1.0)
    trainer.fit(model)

    # training must have run to completion
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # exactly two validation loaders are expected
    val_loaders = trainer.val_dataloaders
    assert len(val_loaders) == 2, "Multiple val_dataloaders not initiated properly"

    # predictions are checked separately on every validation set
    for val_loader in val_loaders:
        tpipes.run_prediction_eval_model_template(trained_model=model, dataloader=val_loader)
|
|
|
|
|
|
@pytest.mark.parametrize("ckpt_path", [None, "best", "specific"])
def test_multiple_eval_dataloader(tmpdir, ckpt_path):
    """Run ``validate`` and ``test`` with two dataloaders each, for every parametrized ``ckpt_path`` mode."""

    class MultipleTestDataloaderModel(EvalModelTemplate):
        def test_dataloader(self):
            # two independently created evaluation loaders
            return [self.dataloader(train=False) for _ in range(2)]

        def val_dataloader(self):
            return self.test_dataloader()

        def test_step(self, *args, **kwargs):
            return super().test_step__multiple_dataloaders(*args, **kwargs)

        def validation_step(self, *args, **kwargs):
            # reuse the test step and rename its output keys for validation
            renamed = {}
            for key, value in self.test_step(*args, **kwargs).items():
                renamed[key.replace("test_", "val_")] = value
            return renamed

    model = MultipleTestDataloaderModel()

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=10, limit_train_batches=100)
    trainer.fit(model)

    # "specific" stands for the concrete path of the best checkpoint saved during fit
    if ckpt_path == "specific":
        ckpt_path = trainer.checkpoint_callback.best_model_path

    def check_predictions(loaders):
        # two loaders are expected, and predictions must be good on each
        assert len(loaders) == 2
        for loader in loaders:
            tpipes.run_prediction_eval_model_template(trainer.model, loader)

    trainer.validate(ckpt_path=ckpt_path, verbose=False)
    check_predictions(trainer.val_dataloaders)

    trainer.test(ckpt_path=ckpt_path, verbose=False)
    check_predictions(trainer.test_dataloaders)
|
|
|
|
|
|
def test_train_dataloader_passed_to_fit(tmpdir):
    """Check that fitting finishes when only the train dataloader is handed to ``fit``."""
    model = EvalModelTemplate()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)

    # no val dataloader is passed here, only the train one
    trainer.fit(model, train_dataloaders=model.dataloader(train=True))

    assert trainer.state.finished, f"Training failed with {trainer.state}"
|
|
|
|
|
|
@pytest.mark.parametrize("ckpt_path", [None, "best", "specific"])
@pytest.mark.parametrize("n", (1, 2))
def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n):
    """Check that ``n`` dataloaders can be passed to ``fit``, ``test`` and ``validate``."""
    model = EvalModelTemplate()
    eval_loader = model.dataloader(train=False)
    if n == 1:
        dataloaders = eval_loader
    else:
        # the same loader object is passed twice, with the multi-dataloader hooks
        dataloaders = [eval_loader] * 2
        model.validation_step = model.validation_step__multiple_dataloaders
        model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
        model.test_step = model.test_step__multiple_dataloaders

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)
    trainer.fit(model, train_dataloaders=model.dataloader(train=True), val_dataloaders=dataloaders)

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert len(trainer.val_dataloaders) == n

    # "specific" stands for the concrete path of the best checkpoint saved during fit
    if ckpt_path == "specific":
        ckpt_path = trainer.checkpoint_callback.best_model_path

    trainer.test(dataloaders=dataloaders, ckpt_path=ckpt_path)
    assert len(trainer.test_dataloaders) == n

    trainer.validate(dataloaders=dataloaders, ckpt_path=ckpt_path)
    assert len(trainer.val_dataloaders) == n
|
|
|
|
|
|
class DummyModel(BoringModel):
    """``BoringModel`` that additionally logs ``loss`` in training and ``val_log`` after validation."""

    def validation_epoch_end(self, outputs):
        # ``val_log`` is the key monitored by the ``ModelCheckpoint`` callbacks in this file
        self.log("val_log", self.current_epoch)

    def training_step(self, batch, batch_idx):
        # log the global step under "loss", then defer to the parent implementation
        self.log("loss", self.global_step)
        step_output = super().training_step(batch, batch_idx)
        return step_output
|
|
|
|
|
|
class Counter(Callback):
    """Callback that counts how many epochs and batches were started for train, validation and test."""

    def __init__(self):
        super().__init__()
        # epochs started per stage
        self.train_epoch_count = 0
        self.val_epoch_count = 0
        self.test_epoch_count = 0
        # batches started per stage
        self.train_batches_seen = 0
        self.val_batches_seen = 0
        self.test_batches_seen = 0

    # --- epoch hooks ---

    def on_train_epoch_start(self, trainer, pl_module):
        self.train_epoch_count += 1

    def on_validation_epoch_start(self, trainer, pl_module):
        self.val_epoch_count += 1

    def on_test_epoch_start(self, trainer, pl_module):
        self.test_epoch_count += 1

    # --- batch hooks ---

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        self.train_batches_seen += 1

    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        self.val_batches_seen += 1

    def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        self.test_batches_seen += 1
|
|
|
|
|
|
@pytest.mark.parametrize(
    ["limit_train_batches", "limit_val_batches", "limit_test_batches"], [(0.0, 0.0, 0.0), (1.0, 1.0, 1.0)]
)
def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
    """Check the batch counts reported for iterable-dataset loaders when the limits are given as floats."""
    inf = float("inf")
    counter = Counter()
    checkpoint = ModelCheckpoint(monitor="val_log", save_top_k=1, mode="max", verbose=False)
    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=0,
        max_epochs=1,
        callbacks=[counter, checkpoint],
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        limit_test_batches=limit_test_batches,
    )
    model = DummyModel()

    batch_size = 8
    # three independent loaders, each over its own dataset instance
    train_dl, val_dl, test_dl = (
        DataLoader(dataset=RandomIterableDataset(32, 128), batch_size=batch_size) for _ in range(3)
    )

    # sanity check of the loaders themselves: use ``len`` when available, otherwise count by iterating
    num_batches = 128 / batch_size
    for dl in (train_dl, val_dl, test_dl):
        counted = len(dl) if has_len_all_ranks(dl, trainer.training_type_plugin, model) else sum(1 for _ in dl)
        assert counted == num_batches

    trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert trainer.num_training_batches == inf
    assert counter.train_epoch_count == int(limit_train_batches > 0)
    # with limit_val_batches == 0 no val data is loaded, so ``num_val_batches`` has no entry to inspect
    if limit_val_batches != 0.0:
        assert trainer.num_val_batches[0] == inf
    assert counter.val_epoch_count == int(limit_val_batches > 0)

    trainer.test(model, dataloaders=test_dl)
    expected_test_batches = 0 if limit_test_batches == 0.0 else inf
    assert trainer.num_test_batches[0] == expected_test_batches
    assert counter.test_epoch_count == int(limit_test_batches > 0)
|
|
|
|
|
|
@pytest.mark.parametrize(
    ["dataset", "limit_train_batches"],
    [
        # a fresh dataset instance for every (dataset class, limit) combination
        (dataset_cls(32, 128), limit)
        for dataset_cls in (RandomDataset, RandomIterableDataset, RandomIterableDatasetWithLen)
        for limit in (0, 10)
    ],
)
def test_dataloaders_with_limit_train_batches(tmpdir, dataset, limit_train_batches):
    """Check the training batch and epoch counts when ``limit_train_batches`` is given as a number."""
    epochs = 2
    counter = Counter()
    checkpoint = ModelCheckpoint(monitor="val_log", save_top_k=1, mode="max", verbose=False)
    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=0,
        max_epochs=epochs,
        callbacks=[counter, checkpoint],
        limit_train_batches=limit_train_batches,
    )
    model = DummyModel()

    # train and val loaders wrap the same dataset object
    train_dl, val_dl = (DataLoader(dataset=dataset, batch_size=8) for _ in range(2))

    trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    expected_num_batches = limit_train_batches if limit_train_batches != 0.0 else float("inf")
    expected_epochs = epochs if limit_train_batches > 0 else 0
    assert trainer.num_training_batches == expected_num_batches
    assert counter.train_epoch_count == expected_epochs
    assert counter.train_batches_seen == limit_train_batches * epochs
|
|
|
|
|
|
@pytest.mark.parametrize(
    ["dataset", "limit_val_batches"],
    [
        # a fresh dataset instance for every (dataset class, limit) combination
        (dataset_cls(32, 128), limit)
        for dataset_cls in (RandomDataset, RandomIterableDataset, RandomIterableDatasetWithLen)
        for limit in (0, 10)
    ],
)
def test_dataloaders_with_limit_val_batches(tmpdir, dataset, limit_val_batches):
    """Check the validation batch and epoch counts when ``limit_val_batches`` is given as a number."""
    epochs = 2
    counter = Counter()
    callbacks = [counter]
    # the checkpoint callback (which monitors ``val_log``) is only added when validation is enabled
    enable_checkpointing = limit_val_batches > 0
    if enable_checkpointing:
        callbacks.append(ModelCheckpoint(monitor="val_log", save_top_k=1, mode="max", verbose=False))

    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=0,
        max_epochs=epochs,
        callbacks=callbacks,
        limit_val_batches=limit_val_batches,
        enable_checkpointing=enable_checkpointing,
    )
    model = DummyModel()

    # train and val loaders wrap the same dataset object
    train_dl, val_dl = (DataLoader(dataset=dataset, batch_size=8) for _ in range(2))

    trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    expected_epochs = epochs if limit_val_batches > 0 else 0
    assert trainer.num_val_batches[0] == limit_val_batches
    assert counter.val_epoch_count == expected_epochs
    assert counter.val_batches_seen == limit_val_batches * epochs
|
|
|
|
|
|
@pytest.mark.parametrize(
    ["dataset", "limit_train_batches", "limit_val_batches", "limit_test_batches"],
    [
        # a fresh dataset instance per case; all three limits share the same value
        (dataset_cls(32, 128), limit, limit, limit)
        for dataset_cls in (RandomDataset, RandomIterableDataset, RandomIterableDatasetWithLen)
        for limit in (0, 10)
    ],
)
def test_datasets_dataloaders_with_limit_num_batches(
    tmpdir, dataset, limit_train_batches, limit_val_batches, limit_test_batches
):
    """Check train, val and test batch/epoch counts when all three limits are given as numbers."""
    epochs = 2
    counter = Counter()
    checkpoint = ModelCheckpoint(monitor="val_log", save_top_k=1, mode="max", verbose=False)
    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=0,
        max_epochs=epochs,
        callbacks=[counter, checkpoint],
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        limit_test_batches=limit_test_batches,
    )
    model = DummyModel()

    # all three loaders wrap the same dataset object
    train_dl, val_dl, test_dl = (DataLoader(dataset=dataset, batch_size=8) for _ in range(3))

    trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    expected_train_batches = limit_train_batches if limit_train_batches > 0.0 else float("inf")
    assert trainer.num_training_batches == expected_train_batches
    if limit_val_batches == 0.0:
        assert trainer.num_val_batches == []
    else:
        assert trainer.num_val_batches[0] == limit_val_batches

    assert counter.train_epoch_count == (epochs if limit_train_batches > 0 else 0)
    assert counter.train_batches_seen == limit_train_batches * epochs
    assert counter.val_epoch_count == (epochs if limit_val_batches > 0 else 0)
    assert counter.val_batches_seen == limit_val_batches * epochs

    trainer.test(model, dataloaders=test_dl)
    assert trainer.num_test_batches[0] == limit_test_batches
    assert counter.test_epoch_count == int(limit_test_batches > 0)
|
|
|
|
|
|
@pytest.mark.parametrize(
    ["limit_train_batches", "limit_val_batches", "limit_test_batches"],
    [(0.0, 0.0, 0.0), (0, 0, 0.5), (1.0, 1.0, 1.0), (0.2, 0.4, 0.4)],
)
def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
    """Check ``num_*_batches`` for train, val and test loaders when the limits are given as fractions."""
    model = EvalModelTemplate()
    # switch val/test over to the multi-dataloader hook variants
    hook_variants = (
        ("val_dataloader", "val_dataloader__multiple_mixed_length"),
        ("test_dataloader", "test_dataloader__multiple_mixed_length"),
        ("validation_step", "validation_step__multiple_dataloaders"),
        ("validation_epoch_end", "validation_epoch_end__multiple_dataloaders"),
        ("test_step", "test_step__multiple_dataloaders"),
        ("test_epoch_end", "test_epoch_end__multiple_dataloaders"),
    )
    for hook_name, variant_name in hook_variants:
        setattr(model, hook_name, getattr(model, variant_name))

    def scaled_lengths(loaders, fraction):
        # expected batch count per loader: its length scaled by the fraction, truncated to int
        return [int(len(loader) * fraction) for loader in loaders]

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        limit_test_batches=limit_test_batches,
    )
    trainer.fit(model)

    if limit_train_batches == 0.0:
        assert trainer.train_dataloader is None
    else:
        expected_train_batches = int(len(trainer.train_dataloader) * limit_train_batches)
        expected_val_batches = scaled_lengths(trainer.val_dataloaders, limit_val_batches)
        assert trainer.num_training_batches == expected_train_batches
        assert trainer.num_val_batches == expected_val_batches

    trainer.test(model)
    assert trainer.num_test_batches == scaled_lengths(trainer.test_dataloaders, limit_test_batches)
|
|
|
|
|
|
@pytest.mark.parametrize(
    ["limit_train_batches", "limit_val_batches", "limit_test_batches"], [(0, 0, 0), (1, 2, 3), (1, 2, 1e50)]
)
def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
    """Check ``num_*_batches`` and evaluation-step call counts when the limits are given as numbers."""
    model = EvalModelTemplate()
    # switch val/test over to the multi-dataloader hook variants
    hook_variants = (
        ("val_dataloader", "val_dataloader__multiple_mixed_length"),
        ("test_dataloader", "test_dataloader__multiple_mixed_length"),
        ("validation_step", "validation_step__multiple_dataloaders"),
        ("validation_epoch_end", "validation_epoch_end__multiple_dataloaders"),
        ("test_step", "test_step__multiple_dataloaders"),
        ("test_epoch_end", "test_epoch_end__multiple_dataloaders"),
    )
    for hook_name, variant_name in hook_variants:
        setattr(model, hook_name, getattr(model, variant_name))

    # train, multiple val and multiple test loaders, limited by a number of batches
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        limit_test_batches=limit_test_batches,
        num_sanity_val_steps=0,
    )

    # wrap the validation step so that its calls can be counted while it still executes
    val_epoch_loop = trainer.fit_loop.epoch_loop.val_loop.epoch_loop
    with patch.object(val_epoch_loop, "_evaluation_step", wraps=val_epoch_loop._evaluation_step) as val_step_spy:
        trainer.fit(model)
        expected_train_batches = limit_train_batches if limit_train_batches != 0.0 else float("inf")
        assert trainer.num_training_batches == expected_train_batches
        if limit_train_batches == 0.0:
            assert trainer.val_dataloaders is None
        else:
            num_val_loaders = len(trainer.val_dataloaders)
            assert trainer.num_val_batches == [limit_val_batches] * num_val_loaders
            assert val_step_spy.call_count == limit_val_batches * num_val_loaders

    # same counting approach for the test step
    test_epoch_loop = trainer.test_loop.epoch_loop
    with patch.object(test_epoch_loop, "_evaluation_step", wraps=test_epoch_loop._evaluation_step) as test_step_spy:
        trainer.test(model)
        loader_lengths = [len(loader) for loader in model.test_dataloader()]
        if limit_test_batches > 1e10:
            # a limit larger than the loaders themselves falls back to the real loader lengths
            assert trainer.num_test_batches == loader_lengths
            assert test_step_spy.call_count == sum(loader_lengths)
        else:
            num_test_loaders = len(trainer.test_dataloaders)
            assert trainer.num_test_batches == [limit_test_batches] * num_test_loaders
            assert test_step_spy.call_count == limit_test_batches * num_test_loaders
|
|
|
|
|
|
@pytest.mark.parametrize("fast_dev_run", [True, 1, 3, -1])
def test_dataloaders_with_fast_dev_run(tmpdir, fast_dev_run):
    """Check the Trainer limits and ``num_*_batches`` that result from each ``fast_dev_run`` value."""
    model = EvalModelTemplate()
    model.val_dataloader = model.val_dataloader__multiple_mixed_length
    model.test_dataloader = model.test_dataloader__multiple_mixed_length
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
    model.test_step = model.test_step__multiple_dataloaders
    model.test_epoch_end = model.test_epoch_end__multiple_dataloaders

    trainer_options = dict(default_root_dir=tmpdir, max_epochs=2, fast_dev_run=fast_dev_run)

    # a negative value must be rejected by the constructor; nothing else to check in that case
    if fast_dev_run == -1:
        with pytest.raises(MisconfigurationException, match="should be >= 0"):
            Trainer(**trainer_options)
        return

    trainer = Trainer(**trainer_options)

    # a value of 1 is expected to be stored as ``True`` on the Trainer
    expected_flag = True if fast_dev_run == 1 else fast_dev_run
    assert trainer.fast_dev_run is expected_flag

    # ``True`` stands for a single batch per loader
    num_batches = 1 if expected_flag is True else expected_flag

    assert trainer.limit_train_batches == num_batches
    assert trainer.limit_val_batches == num_batches
    assert trainer.limit_test_batches == num_batches
    assert trainer.num_sanity_val_steps == 0
    # max_epochs=2 was requested above, yet a single epoch is expected
    assert trainer.max_epochs == 1

    trainer.fit(model)
    assert trainer.enable_validation
    assert trainer.num_training_batches == num_batches
    assert trainer.num_val_batches == [num_batches] * len(trainer.val_dataloaders)

    trainer.test(model)
    assert trainer.num_test_batches == [num_batches] * len(trainer.test_dataloaders)
|
|
|
|
|
|
@pytest.mark.parametrize("ckpt_path", [None, "best", "specific"])
def test_mixing_of_dataloader_options(tmpdir, ckpt_path):
    """Verify that a val dataloader passed to ``fit`` and a test dataloader passed to ``test`` are each registered
    exactly once.

    ``ckpt_path`` is forwarded to ``trainer.test``; the value ``"specific"`` is replaced with the path of the best
    checkpoint saved during ``fit``.
    """
    model = EvalModelTemplate()

    trainer_options = dict(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)

    # fit once: only the val dataloader is passed explicitly, no train dataloader is given to ``fit``
    trainer = Trainer(**trainer_options)
    trainer.fit(model, val_dataloaders=model.dataloader(train=False))
    assert len(trainer.val_dataloaders) == 1, f"`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}"

    if ckpt_path == "specific":
        ckpt_path = trainer.checkpoint_callback.best_model_path
    trainer.test(dataloaders=model.dataloader(train=False), ckpt_path=ckpt_path)
    assert (
        len(trainer.test_dataloaders) == 1
    ), f"`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}"
|
|
|
|
|
|
def test_train_inf_dataloader_error(tmpdir):
    """Check that the infinite train dataloader variant with ``val_check_interval=0.5`` makes ``fit`` raise."""
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=0.5)

    model = EvalModelTemplate()
    model.train_dataloader = model.train_dataloader__infinite

    with pytest.raises(MisconfigurationException, match="using an IterableDataset"):
        trainer.fit(model)
|
|
|
|
|
|
def test_val_inf_dataloader_error(tmpdir):
    """Check that the infinite val dataloader variant with ``limit_val_batches=0.5`` makes ``fit`` raise."""
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.5)

    model = EvalModelTemplate()
    model.val_dataloader = model.val_dataloader__infinite

    with pytest.raises(MisconfigurationException, match="using an IterableDataset"):
        trainer.fit(model)
|
|
|
|
|
|
def test_test_inf_dataloader_error(tmpdir):
    """Check that the infinite test dataloader variant with ``limit_test_batches=0.5`` makes ``test`` raise."""
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=0.5)

    model = EvalModelTemplate()
    model.test_dataloader = model.test_dataloader__infinite

    with pytest.raises(MisconfigurationException, match="using an IterableDataset"):
        trainer.test(model)
|
|
|
|
|
|
@pytest.mark.parametrize("check_interval", [50, 1.0])
def test_inf_train_dataloader(tmpdir, check_interval):
    """Check that fitting finishes with the infinite train dataloader variant for each ``val_check_interval``."""
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=check_interval)

    model = EvalModelTemplate()
    model.train_dataloader = model.train_dataloader__infinite
    trainer.fit(model)

    # the run must end in the finished state
    assert trainer.state.finished, f"Training failed with {trainer.state}"
|
|
|
|
|
|
@pytest.mark.parametrize("check_interval", [1.0])
def test_inf_val_dataloader(tmpdir, check_interval):
    """Check that fitting finishes with the infinite val dataloader variant."""
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=check_interval)

    model = EvalModelTemplate()
    model.val_dataloader = model.val_dataloader__infinite
    trainer.fit(model)

    # the run must end in the finished state
    assert trainer.state.finished, f"Training failed with {trainer.state}"
|
|
|
|
|
|
def test_error_on_zero_len_dataloader(tmpdir):
    """Test that ``fit`` raises a ``ValueError`` if a zero-length train dataloader is defined."""
    model = EvalModelTemplate()
    model.train_dataloader = model.train_dataloader__zero_length

    # build the Trainer outside of ``pytest.raises`` so that only an error raised by ``fit`` satisfies the test;
    # a ``ValueError`` coming from the constructor must fail the test instead of passing it
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=0.1,
        limit_val_batches=0.1,
        limit_test_batches=0.1,
    )
    with pytest.raises(ValueError):
        trainer.fit(model)
|
|
|
|
|
|
@RunIf(skip_windows=True)
@pytest.mark.parametrize("ckpt_path", (None, "best", "specific"))
@pytest.mark.parametrize("stage", ("train", "test", "val"))
@patch("pytorch_lightning.trainer.data_loading.multiprocessing.cpu_count", return_value=4)
def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage):
    """Test that a warning is raised if a dataloader with only a few workers is used."""
    model = BoringModel()

    # both loaders are set to zero workers while ``cpu_count`` is patched to report 4
    train_dl = model.train_dataloader()
    val_dl = model.val_dataloader()
    for loader in (train_dl, val_dl):
        loader.num_workers = 0

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)

    expected_warning = (
        f'The dataloader, {stage}_dataloader{" 0" if stage != "train" else ""}, does not have many workers'
    )
    with pytest.warns(UserWarning, match=expected_warning):
        if stage != "test":
            trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
        else:
            # loading from a checkpoint requires a previous fit
            if ckpt_path in ("specific", "best"):
                trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
                if ckpt_path == "specific":
                    ckpt_path = trainer.checkpoint_callback.best_model_path
            trainer.test(model, dataloaders=train_dl, ckpt_path=ckpt_path)
|
|
|
|
|
|
@RunIf(skip_windows=True)
@pytest.mark.parametrize("ckpt_path", (None, "best", "specific"))
@pytest.mark.parametrize("stage", ("train", "test", "val"))
@patch("pytorch_lightning.trainer.data_loading.multiprocessing.cpu_count", return_value=4)
def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):
    """Test that a warning is raised if collections of dataloaders with only a few workers are used."""
    model = EvalModelTemplate()
    # use the multi-dataloader variants of all step/epoch-end hooks
    model.training_step = model.training_step__multiple_dataloaders
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
    model.test_step = model.test_step__multiple_dataloaders
    model.test_epoch_end = model.test_epoch_end__multiple_dataloaders

    # both loaders are set to zero workers while ``cpu_count`` is patched to report 4
    val_dl = model.dataloader(train=False)
    train_dl = model.dataloader(train=False)
    for loader in (val_dl, train_dl):
        loader.num_workers = 0

    # the same loader object is repeated inside every collection
    train_multi_dl = {"a_b": [train_dl] * 2, "c_d_e": [train_dl] * 3}
    val_multi_dl = [val_dl] * 2
    test_multi_dl = [train_dl] * 2

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)

    expected_warning = (
        f'The dataloader, {stage}_dataloader{" 0" if stage != "train" else ""}, does not have many workers'
    )
    with pytest.warns(UserWarning, match=expected_warning):
        if stage != "test":
            trainer.fit(model, train_dataloaders=train_multi_dl, val_dataloaders=val_multi_dl)
        else:
            # loading from a checkpoint requires a previous fit
            if ckpt_path in ("specific", "best"):
                trainer.fit(model, train_dataloaders=train_multi_dl, val_dataloaders=val_multi_dl)
                if ckpt_path == "specific":
                    ckpt_path = trainer.checkpoint_callback.best_model_path
            trainer.test(model, dataloaders=test_multi_dl, ckpt_path=ckpt_path)
|
|
|
|
|
|
class NumpyRandomDataset(Dataset):
    """Fixed-size dataset whose items are drawn from numpy's global random state rather than torch's."""

    # number of samples reported by ``__len__``
    size = 16

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        # ``index`` is ignored: every access draws three fresh integers from [0, 100)
        return numpy.random.randint(low=0, high=100, size=3)
|
|
|
|
|
|
def _user_worker_init_fn(_):
    """No-op stand-in for a user-supplied ``worker_init_fn``; the argument is ignored."""
    return None
|
|
|
|
|
|
@RunIf(max_torch="1.8.9")
def test_missing_worker_init_fn():
    """Test that naive worker seed initialization leads to undesired random state in subprocesses.

    PyTorch 1.9+ does not have this issue.
    """
    dataset = NumpyRandomDataset()

    def seeded_epoch():
        # reseed the main process, then collect one full pass through a fresh two-worker loader
        seed_everything(0)
        loader = DataLoader(dataset, batch_size=2, num_workers=2, shuffle=False)
        return torch.cat(list(loader))

    first_pass = seeded_epoch()
    second_pass = seeded_epoch()

    is_duplicated = len(torch.unique(second_pass, dim=0)) < len(dataset)
    is_deterministic = torch.eq(first_pass, second_pass).all()

    # depending on the OS, we either have
    # 1) the same seed in all worker processes, producing duplicate samples / augmentations, or
    # 2) different seeds in each worker process, but they are not derived from the seed of the main process
    assert not is_deterministic or is_duplicated
|
|
|
|
|
|
def test_auto_add_worker_init_fn():
    """Test when ``_auto_add_worker_init_fn`` sets a ``worker_init_fn`` on a dataloader and when it leaves it alone."""
    loader = DataLoader(Mock())

    # 1) ``seed_everything`` was not called in this test yet: nothing is added
    _auto_add_worker_init_fn(loader, 0)
    assert loader.worker_init_fn is None

    # 2) seeding with ``workers=False``: still nothing is added
    seed_everything(0, workers=False)
    _auto_add_worker_init_fn(loader, 0)
    assert loader.worker_init_fn is None

    # 3) a function supplied by the user is kept as is
    loader.worker_init_fn = _user_worker_init_fn
    _auto_add_worker_init_fn(loader, 0)
    assert loader.worker_init_fn is _user_worker_init_fn
    loader.worker_init_fn = None

    # 4) main use case, seeding with ``workers=True``: a function gets attached
    seed_everything(0, workers=True)
    _auto_add_worker_init_fn(loader, 0)
    assert loader.worker_init_fn is not None
|
|
|
|
|
|
class MultiProcessModel(BoringModel):
    """``BoringModel`` that records every training batch and checks at epoch end that all gathered rows are unique."""

    def __init__(self):
        super().__init__()
        # batches received by ``training_step`` on this process
        self.batches_seen = []

    def training_step(self, batch, batch_idx):
        # only record the batch; nothing is returned
        self.batches_seen.append(batch)

    def training_epoch_end(self, outputs):
        expected_world_size = 2
        expected_unique_rows = NumpyRandomDataset.size

        local_rows = torch.cat(self.batches_seen)
        gathered = self.all_gather(local_rows)
        # the gathered tensor is expected to have one leading entry per process
        assert gathered.shape[0] == expected_world_size

        # flatten to rows of 3 values, the item size produced by ``NumpyRandomDataset``
        flat_rows = gathered.view(-1, 3)
        assert len(torch.unique(flat_rows, dim=0)) == expected_unique_rows
|
|
|
|
|
|
@RunIf(min_gpus=2)
def test_auto_add_worker_init_fn_distributed(tmpdir, monkeypatch):
    """Run a 2-GPU ``ddp_spawn`` fit with worker seeding enabled and a multi-worker dataloader.

    The actual verification happens inside ``MultiProcessModel.training_epoch_end``.
    """
    train_loader = DataLoader(NumpyRandomDataset(), batch_size=2, num_workers=2)

    # enable seeding of dataloader workers before the trainer is created
    seed_everything(0, workers=True)

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, gpus=2, strategy="ddp_spawn")
    model = MultiProcessModel()
    # disable validation, only the training batches are of interest
    model.val_dataloader = None
    trainer.fit(model, train_dataloaders=train_loader)
|
|
|
|
|
|
def test_warning_with_small_dataloader_and_logging_interval(tmpdir):
    """Expect a ``UserWarning`` when ``log_every_n_steps`` exceeds the number of training batches."""
    model = BoringModel()
    small_loader = DataLoader(RandomDataset(32, length=10))

    def _train_dataloader():
        return small_loader

    model.train_dataloader = _train_dataloader

    # (expected warning pattern, extra Trainer arguments)
    scenarios = (
        # 10 batches in the loader, logging every 11 steps
        (
            r"The number of training samples \(10\) is smaller than the logging interval",
            dict(log_every_n_steps=11),
        ),
        # training limited to a single batch, logging every 2 steps
        (
            r"The number of training samples \(1\) is smaller than the logging interval",
            dict(log_every_n_steps=2, limit_train_batches=1),
        ),
    )

    for pattern, trainer_kwargs in scenarios:
        with pytest.warns(UserWarning, match=pattern):
            trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, **trainer_kwargs)
            trainer.fit(model)
|
|
|
|
|
|
def test_warning_with_iterable_dataset_and_len(tmpdir):
    """Expect a warning in every trainer entry point when an ``IterableDataset`` also defines ``__len__``.

    The same dataset without ``__len__`` is then run through all entry points as a counterpart.
    """
    model = BoringModel()
    source_dataset = model.train_dataloader().dataset

    class IterableWithoutLen(IterableDataset):
        def __iter__(self):
            return iter(source_dataset)

    class IterableWithLen(IterableWithoutLen):
        def __len__(self):
            return len(source_dataset)

    len_warning = "Your `IterableDataset` has `__len__` defined."

    # --- dataset exposing __len__ ---
    trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
    sized_loader = DataLoader(IterableWithLen(), batch_size=16)
    assert has_len_all_ranks(sized_loader, trainer.training_type_plugin, model)
    assert has_iterable_dataset(sized_loader)

    with pytest.warns(UserWarning, match=len_warning):
        trainer.validate(model, dataloaders=[sized_loader])

    with pytest.warns(UserWarning, match=len_warning):
        trainer.fit(model, train_dataloaders=sized_loader, val_dataloaders=[sized_loader])

    with pytest.warns(UserWarning, match=len_warning):
        trainer.test(model, dataloaders=[sized_loader])

    with pytest.warns(UserWarning, match=len_warning):
        trainer.predict(model, dataloaders=[sized_loader])

    # --- dataset without __len__ ---
    trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
    unsized_loader = DataLoader(IterableWithoutLen(), batch_size=16)
    assert not has_len_all_ranks(unsized_loader, trainer.training_type_plugin, model)
    assert has_iterable_dataset(unsized_loader)

    trainer.validate(model, dataloaders=unsized_loader)
    trainer.fit(model, train_dataloaders=unsized_loader, val_dataloaders=[unsized_loader])
    trainer.test(model, dataloaders=unsized_loader)
    trainer.predict(model, dataloaders=unsized_loader)
|
|
|
|
|
|
@pytest.mark.parametrize("yield_at_all", (False, True))
def test_iterable_dataset_stop_iteration_at_epoch_beginning(yield_at_all, tmpdir):
    """Test that the training loop skips execution if the iterator is empty from the start.

    Args:
        yield_at_all: whether the generator produces any data in epoch 0. It never produces data afterwards.
        tmpdir: pytest temporary directory used as the trainer root, so that no artifacts are written to the
            current working directory.
    """

    class TestDataset(IterableDataset):
        def __init__(self, gen):
            # ``gen`` is a zero-argument callable returning a fresh iterable on every call
            self.gen = gen

        def __iter__(self):
            return iter(self.gen())

    class TestModel(BoringModel):
        def gen(self):
            # produce data in epoch 0, no data otherwise
            if yield_at_all and self.current_epoch == 0:
                yield torch.rand(32)
                yield torch.rand(32)
                yield torch.rand(32)

    model = TestModel()
    train_dataloader = DataLoader(TestDataset(model.gen), batch_size=2)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        enable_model_summary=False,
    )
    trainer.fit(model, train_dataloaders=train_dataloader)

    # 3 samples with batch_size=2 give 2 steps when data is produced, 0 steps otherwise
    assert trainer.global_step == 2 * yield_at_all
    # we expect the second epoch to be skipped
    assert trainer.current_epoch == 1
|
|
|
|
|
|
class DistribSamplerCallback(Callback):
    """Callback asserting that the train/val/test dataloaders use a ``DistributedSampler`` with expected settings.

    Args:
        expected_seeds: expected sampler seeds, indexed as ``(train, val, test)``.
    """

    def __init__(self, expected_seeds=(0, 0, 0)):
        self.expected_seed = expected_seeds

    def _assert_distributed_sampler(self, sampler, stage_index, shuffled):
        """Check the sampler type, its shuffle flag and its seed for the stage at ``stage_index``."""
        assert isinstance(sampler, DistributedSampler)
        assert bool(sampler.shuffle) is shuffled
        assert sampler.seed == self.expected_seed[stage_index]

    def on_train_start(self, trainer, pl_module):
        # the training sampler is expected to shuffle
        self._assert_distributed_sampler(trainer.train_dataloader.sampler, 0, shuffled=True)

    def on_validation_start(self, trainer, pl_module):
        # only the first validation dataloader is inspected; it must not shuffle
        self._assert_distributed_sampler(trainer.val_dataloaders[0].sampler, 1, shuffled=False)

    def on_test_start(self, trainer, pl_module):
        # only the first test dataloader is inspected; it must not shuffle
        self._assert_distributed_sampler(trainer.test_dataloaders[0].sampler, 2, shuffled=False)
|
|
|
|
|
|
@RunIf(min_gpus=2, skip_windows=True)
def test_dataloader_distributed_sampler(tmpdir):
    """Test the ``DistributedSampler`` and its arguments for the ``ddp_spawn`` strategy.

    The global seed is 123 and the attached callback expects that seed on the train, val and test samplers.
    """
    seed_everything(123)
    sampler_check = DistribSamplerCallback(expected_seeds=(123, 123, 123))
    model = EvalModelTemplate()

    trainer = Trainer(
        default_root_dir=tmpdir,
        strategy="ddp_spawn",
        gpus=[0, 1],
        num_nodes=1,
        max_steps=1,
        callbacks=[sampler_check],
    )
    trainer.fit(model)
    trainer.test(model)
|
|
|
|
|
|
class ModelWithDataLoaderDistributedSampler(EvalModelTemplate):
    """Template model whose training dataloader already carries a shuffling ``DistributedSampler`` (seed 11)."""

    def train_dataloader(self):
        # reuse the parent's dataset, but rebuild the loader around our own sampler
        dataset = super().train_dataloader().dataset
        sampler = DistributedSampler(dataset, shuffle=True, seed=11)
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            shuffle=False,
            drop_last=False,
        )
|
|
|
|
|
|
@RunIf(min_gpus=2, skip_windows=True)
def test_dataloader_distributed_sampler_already_attached(tmpdir):
    """Test the ``DistributedSampler`` and its arguments when the training dataloader already includes one.

    The callback expects the training sampler to keep its own seed (11) and the validation sampler to carry
    the global seed (123).
    """
    seed_everything(123)
    sampler_check = DistribSamplerCallback(expected_seeds=(11, 123, 0))
    model = ModelWithDataLoaderDistributedSampler()

    trainer = Trainer(
        default_root_dir=tmpdir,
        strategy="ddp_spawn",
        gpus=[0, 1],
        num_nodes=1,
        max_steps=100,
        replace_sampler_ddp=True,
        callbacks=[sampler_check],
    )
    trainer.fit(model)
    assert trainer.state.finished, "DDP Training failed"
|
|
|
|
|
|
@RunIf(min_gpus=3)
def test_batch_size_smaller_than_num_gpus(tmpdir):
    """Fit on 3 GPUs with a dataset whose size is not divisible by the number of GPUs."""
    # we need at least 3 gpus for this test
    num_gpus = 3
    batch_size = 3

    class CurrentTestModel(EvalModelTemplate):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)
            # batch norm doesn't work with batch size 1, we replace it
            self.c_d1_bn = torch.nn.ReLU()

        def training_step(self, *args, **kwargs):
            step_output = super().training_step(*args, **kwargs)
            # we make sure to add some metrics to the output dict,
            # this is essential for this test
            step_output["progress_bar"] = {"train_loss": step_output["loss"]}
            return step_output

        def train_dataloader(self):
            full_dataset = super().train_dataloader().dataset
            # construct a dataset with a size that is not divisible by num_gpus
            # therefore the last batch will have a size < num_gpus
            num_items = num_gpus * batch_size + (num_gpus - 1)
            truncated = Subset(full_dataset, range(num_items))
            return DataLoader(truncated, batch_size=self.batch_size, drop_last=False)

    hparams = EvalModelTemplate.get_default_hparams()
    hparams["batch_size"] = batch_size
    model = CurrentTestModel(**hparams)

    trainer = Trainer(
        default_root_dir=tmpdir,
        gpus=num_gpus,
        max_epochs=1,
        limit_train_batches=0.1,
        limit_val_batches=0,
    )

    # we expect the reduction for the metrics also to happen on the last batch
    # where we will get fewer metrics than gpus
    trainer.fit(model)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
|
|
|
|
|
|
@pytest.mark.parametrize(
    ["multiple_trainloader_mode", "num_training_batches"],
    [("min_size", 5), ("max_size_cycle", 10)],
)
def test_fit_multiple_train_loaders(tmpdir, multiple_trainloader_mode, num_training_batches):
    """Integration test for multiple train loaders.

    The number of training batches reported by the trainer must match the value expected for the chosen
    ``multiple_trainloader_mode``.
    """
    model = EvalModelTemplate()

    # swap in the multi-loader variants provided by the template
    # todo: add also `train_dataloader__multiple_sequence`
    model.train_dataloader = model.train_dataloader__multiple_mapping
    model.training_step = model.training_step__multiple_dataloaders

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        multiple_trainloader_mode=multiple_trainloader_mode,
    )
    trainer.fit(model)

    # verify the num_training_batches according to the multiple_trainloader_mode
    assert num_training_batches == trainer.num_training_batches
|
|
|
|
|
|
@pytest.mark.parametrize("check_interval", [1.0])
def test_val_dataloader_not_implemented_error(tmpdir, check_interval):
    """Fit with the template's ``val_dataloader__not_implemented_error`` loader (e.g. IterableDataset)."""
    model = EvalModelTemplate()
    model.val_dataloader = model.val_dataloader__not_implemented_error

    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        max_steps=5,
        max_epochs=1,
        val_check_interval=check_interval,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)

    # verify training completed
    assert trainer.state.finished, f"Training failed with {trainer.state}"
|
|
|
|
|
|
@pytest.mark.parametrize("check_interval", [50, 1.0])
def test_train_dataloader_not_implemented_error(tmpdir, check_interval):
    """Fit with the template's ``*_dataloader__not_implemented_error`` train and val loaders.

    These stand for loaders such as ones wrapping an IterableDataset.
    """
    model = EvalModelTemplate()
    model.train_dataloader = model.train_dataloader__not_implemented_error
    model.val_dataloader = model.val_dataloader__not_implemented_error

    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        max_steps=5,
        max_epochs=1,
        val_check_interval=check_interval,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)

    # verify training completed
    assert trainer.state.finished, f"Training failed with {trainer.state}"
|
|
|
|
|
|
def test_train_dataloader_not_implemented_error_failed(tmpdir):
    """Expect a ``MisconfigurationException`` for a fractional ``val_check_interval`` with the template's
    ``train_dataloader__not_implemented_error`` loader (e.g. IterableDataset)."""
    model = EvalModelTemplate()
    model.train_dataloader = model.train_dataloader__not_implemented_error

    trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, val_check_interval=0.5)

    expect_misconfiguration = pytest.raises(MisconfigurationException, match="using an IterableDataset")
    with expect_misconfiguration:
        trainer.fit(model)
|
|
|
|
|
|
def test_val_dataloader_not_implemented_error_failed(tmpdir):
    """Expect a ``MisconfigurationException`` for a fractional ``limit_val_batches`` with the template's
    ``val_dataloader__not_implemented_error`` loader (e.g. IterableDataset)."""
    model = EvalModelTemplate()
    model.val_dataloader = model.val_dataloader__not_implemented_error

    trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, limit_val_batches=0.5)

    expect_misconfiguration = pytest.raises(MisconfigurationException, match="using an IterableDataset")
    with expect_misconfiguration:
        trainer.fit(model)
|
|
|
|
|
|
def test_test_dataloader_not_implemented_error_failed(tmpdir):
    """Expect a ``MisconfigurationException`` for a fractional ``limit_test_batches`` with the template's
    ``test_dataloader__not_implemented_error`` loader (e.g. IterableDataset)."""
    model = EvalModelTemplate()
    model.test_dataloader = model.test_dataloader__not_implemented_error

    trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, limit_test_batches=0.5)

    expect_misconfiguration = pytest.raises(MisconfigurationException, match="using an IterableDataset")
    with expect_misconfiguration:
        trainer.test(model)
|
|
|
|
|
|
def test_dataloaders_load_only_once(tmpdir):
    """Check that a 3-epoch fit requests the train and val dataloaders once each and never the test one."""
    model = BoringModel()
    tracker = Mock()

    # spy on each dataloader hook and register the spy on a shared parent so that call order is recorded
    for hook_name in ("train_dataloader", "val_dataloader", "test_dataloader"):
        spy = Mock(wraps=getattr(model, hook_name))
        setattr(model, hook_name, spy)
        tracker.attach_mock(spy, hook_name)

    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=0.3, limit_val_batches=0.3, max_epochs=3)
    trainer.fit(model)

    model.train_dataloader.assert_called_once()
    model.val_dataloader.assert_called_once()
    model.test_dataloader.assert_not_called()

    # the val dataloader is requested before the train dataloader
    assert tracker.mock_calls == [call.val_dataloader(), call.train_dataloader()]
|
|
|
|
|
|
def test_dataloaders_load_only_once_val_interval(tmpdir):
    """Check the order of dataloader requests with ``val_check_interval=0.3`` and reloading every epoch."""
    model = BoringModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=3,
        limit_train_batches=10,
        limit_val_batches=10,
        val_check_interval=0.3,
        reload_dataloaders_every_n_epochs=1,
    )

    # spy on each dataloader hook and register the spy on a shared parent so that call order is recorded
    tracker = Mock()
    for hook_name in ("train_dataloader", "val_dataloader", "test_dataloader"):
        spy = Mock(wraps=getattr(model, hook_name))
        setattr(model, hook_name, spy)
        tracker.attach_mock(spy, hook_name)

    trainer.fit(model)
    trainer.test(model)

    # verify the sequence: one initial val request, then per epoch one train request
    # followed by three val requests, and finally the test request
    per_epoch = [call.train_dataloader()] + [call.val_dataloader()] * 3
    expected_sequence = [call.val_dataloader()] + per_epoch * 3 + [call.test_dataloader()]
    assert tracker.mock_calls == expected_sequence
|
|
|
|
|
|
def test_dataloaders_load_only_once_no_sanity_check(tmpdir):
    """Check the order of dataloader requests during fit when the sanity check is disabled."""
    model = BoringModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=3,
        num_sanity_val_steps=0,
        limit_train_batches=0.3,
        limit_val_batches=0.3,
    )

    # spy on each dataloader hook and register the spy on a shared parent so that call order is recorded
    tracker = Mock()
    for hook_name in ("train_dataloader", "val_dataloader", "test_dataloader"):
        spy = Mock(wraps=getattr(model, hook_name))
        setattr(model, hook_name, spy)
        tracker.attach_mock(spy, hook_name)

    trainer.fit(model)

    # verify the sequence: with no sanity check the train dataloader is requested first
    assert tracker.mock_calls == [call.train_dataloader(), call.val_dataloader()]
|
|
|
|
|
|
@pytest.mark.parametrize("n", [1, 2])
def test_dataloaders_load_every_n_epochs(tmpdir, n):
    """Check the order of dataloader requests over 3 epochs with ``reload_dataloaders_every_n_epochs=n``."""
    model = BoringModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=3,
        reload_dataloaders_every_n_epochs=n,
        limit_train_batches=0.3,
        limit_val_batches=0.3,
    )

    # spy on each dataloader hook and register the spy on a shared parent so that call order is recorded
    tracker = Mock()
    for hook_name in ("train_dataloader", "val_dataloader", "test_dataloader"):
        spy = Mock(wraps=getattr(model, hook_name))
        setattr(model, hook_name, spy)
        tracker.attach_mock(spy, hook_name)

    trainer.fit(model)
    trainer.test(model)

    # verify the sequence: an initial val request, then one (train, val) pair per reload, then the test request
    expected_reloads = {1: 3, 2: 2}.get(n, 0)
    expected_sequence = (
        [call.val_dataloader()]
        + [call.train_dataloader(), call.val_dataloader()] * expected_reloads
        + [call.test_dataloader()]
    )
    assert tracker.mock_calls == expected_sequence
|
|
|
|
|
|
@pytest.mark.parametrize("n", ["test", -1])
def test_dataloaders_load_every_n_epochs_exception(tmpdir, n):
    """Expect a ``MisconfigurationException`` for a string or negative ``reload_dataloaders_every_n_epochs``."""
    invalid_kwargs = dict(default_root_dir=tmpdir, reload_dataloaders_every_n_epochs=n)

    with pytest.raises(MisconfigurationException, match="should be an int >"):
        Trainer(**invalid_kwargs)
|
|
|
|
|
|
def test_dataloaders_load_every_epoch_no_sanity_check(tmpdir):
    """Check the order of dataloader requests when reloading every epoch with the sanity check disabled.

    A ``ModelCheckpoint`` monitoring a metric logged in ``validation_step`` is attached as well.
    """

    class TestModel(BoringModel):
        def validation_step(self, batch, batch_idx):
            self.log("dummy_val", 5.0)
            return super().validation_step(batch, batch_idx)

    model = TestModel()

    # This callback tests that the evaluation metrics are available by the time we run checkpointing
    checkpoint_callback = ModelCheckpoint(monitor="dummy_val", save_top_k=1)

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=3,
        num_sanity_val_steps=0,
        reload_dataloaders_every_n_epochs=1,
        limit_train_batches=0.3,
        limit_val_batches=0.3,
        callbacks=[checkpoint_callback],
    )

    # spy on each dataloader hook and register the spy on a shared parent so that call order is recorded
    tracker = Mock()
    for hook_name in ("train_dataloader", "val_dataloader", "test_dataloader"):
        spy = Mock(wraps=getattr(model, hook_name))
        setattr(model, hook_name, spy)
        tracker.attach_mock(spy, hook_name)

    trainer.fit(model)
    trainer.test(model)

    first_epoch = [
        call.train_dataloader(),
        call.val_dataloader(),
        # This has subsequent calls to val_dataloader
        # because the training loop runs the evaluation loop,
        # which reloads the val dataloader again.
        # We cannot yet rely on trainer.current_epoch=0 to skip reloading
        # the val dataloader on the first epoch because this only tracks the training epoch
        # meaning multiple passes through the validation data within a single training epoch
        # would not have the dataloader reloaded.
        # This breaks the assumption behind reload_dataloaders_every_n_epochs=1
        call.val_dataloader(),
    ]
    later_epochs = [call.train_dataloader(), call.val_dataloader()] * 2
    expected_calls = first_epoch + later_epochs + [call.test_dataloader()]
    assert tracker.mock_calls == expected_calls
|
|
|
|
|
|
def test_dataloaders_load_only_once_passed_loaders(tmpdir):
    """Check that each ``Trainer.reset_*_dataloader`` runs once when the dataloaders are passed explicitly."""
    model = BoringModel()
    train_loader = model.train_dataloader()
    val_loader = model.val_dataloader()
    test_loader = model.test_dataloader()

    # delete dataloader methods on the model
    for hook_name in ("train_dataloader", "val_dataloader", "test_dataloader"):
        setattr(model, hook_name, None)

    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=0.3, limit_val_batches=0.3, max_epochs=3)

    # spy on the trainer's reset methods and register the spies on a shared parent to record call order
    tracker = Mock()
    for reset_name in ("reset_train_dataloader", "reset_val_dataloader", "reset_test_dataloader"):
        spy = Mock(wraps=getattr(trainer, reset_name))
        setattr(trainer, reset_name, spy)
        tracker.attach_mock(spy, reset_name)

    trainer.fit(model, train_loader, val_loader)
    trainer.test(model, dataloaders=test_loader)

    trainer.reset_train_dataloader.assert_called_once()
    trainer.reset_val_dataloader.assert_called_once()
    trainer.reset_test_dataloader.assert_called_once()

    expected_order = [
        call.reset_val_dataloader(),
        call.reset_train_dataloader(model=model),
        call.reset_test_dataloader(),
    ]
    assert tracker.mock_calls == expected_order
|
|
|
|
|
|
def test_dataloaders_reset_and_attach(tmpdir):
    """Test that repeated calls to Trainer.{fit,validate,test,predict} properly reset dataloaders before attaching
    the new one."""
    # the assertions compare the datasets and not dataloaders since we patch and replace the samplers
    dataloader_0, dataloader_1, dataloader_2, dataloader_3 = (
        DataLoader(dataset=RandomDataset(32, 64)) for _ in range(4)
    )
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=1,
        limit_val_batches=1,
        limit_test_batches=1,
        limit_predict_batches=1,
    )

    # two consecutive fits, each with a fresh pair of loaders
    for fit_index, (train_loader, val_loader) in enumerate(
        ((dataloader_0, dataloader_1), (dataloader_2, dataloader_3))
    ):
        if fit_index > 0:
            # raise the step budget so that the next fit has one more step to run
            trainer.fit_loop.max_steps += 1
        trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
        assert trainer.train_dataloader.loaders.dataset is train_loader.dataset
        assert trainer.val_dataloaders[0].dataset is val_loader.dataset

    # validate, test and predict: run each twice and check that the latest loader is the attached one
    entry_points = (
        (trainer.validate, "val_dataloaders"),
        (trainer.test, "test_dataloaders"),
        (trainer.predict, "predict_dataloaders"),
    )
    for run, attached_attr in entry_points:
        for loader in (dataloader_0, dataloader_1):
            run(model, dataloaders=loader)
            assert getattr(trainer, attached_attr)[0].dataset is loader.dataset
|
|
|
|
|
|
@pytest.mark.parametrize("multiple_trainloader_mode", ["min_size", "max_size_cycle"])
def test_correct_dataloader_idx_in_hooks(tmpdir, multiple_trainloader_mode):
    """Check the correct dataloader_idx inside hooks.

    The model uses two dataloaders per stage and runs with ``fast_dev_run=5``, so in validation and testing the
    first five batch-transfer calls are expected to carry ``dataloader_idx == 0`` and the following ones
    ``dataloader_idx == 1``.
    """

    class CustomBoringModel(BoringModel):
        def __init__(self):
            super().__init__()
            # number of batches seen so far while validating / testing
            self.val_call_count = 0
            self.test_call_count = 0

        def assert_dataloader_idx_hook(self, dataloader_idx):
            if self.trainer.training:
                assert dataloader_idx == 0
            elif self.trainer.validating:
                assert dataloader_idx == (0 if self.val_call_count <= 5 else 1)
            elif self.trainer.testing:
                assert dataloader_idx == (0 if self.test_call_count <= 5 else 1)

        def transfer_batch_to_device(self, batch, device, dataloader_idx):
            self.assert_dataloader_idx_hook(dataloader_idx)
            return super().transfer_batch_to_device(batch, device, dataloader_idx)

        def on_before_batch_transfer(self, batch, dataloader_idx):
            # incrementing here since this is the first hook called at each step
            if self.trainer.validating:
                self.val_call_count += 1
            elif self.trainer.testing:
                self.test_call_count += 1

            self.assert_dataloader_idx_hook(dataloader_idx)
            return super().on_before_batch_transfer(batch, dataloader_idx)

        def on_after_batch_transfer(self, batch, dataloader_idx):
            self.assert_dataloader_idx_hook(dataloader_idx)
            return super().on_after_batch_transfer(batch, dataloader_idx)

        def training_step(self, batch, batch_idx):
            # the train loaders are a dict; only the batch of loader "a" is used for the step
            return super().training_step(batch["a"], batch_idx)

        def validation_step(self, batch, batch_idx, dataloader_idx):
            self.assert_dataloader_idx_hook(dataloader_idx)
            out = super().validation_step(batch, batch_idx)
            # re-key the parent's output so the epoch-end check can tell the loaders apart
            loss = out.pop("x")
            out[f"val_loss_{dataloader_idx}"] = loss
            return out

        def test_step(self, batch, batch_idx, dataloader_idx):
            self.assert_dataloader_idx_hook(dataloader_idx)
            out = super().test_step(batch, batch_idx)
            # re-key the parent's output so the epoch-end check can tell the loaders apart
            loss = out.pop("y")
            out[f"test_loss_{dataloader_idx}"] = loss
            return out

        def predict_step(self, batch, batch_idx, dataloader_idx):
            # the prediction hook is named ``predict_step``; a method named ``predict`` is never invoked
            self.assert_dataloader_idx_hook(dataloader_idx)
            return super().predict_step(batch, batch_idx, dataloader_idx)

        def assert_epoch_end_outputs(self, outputs, mode):
            # one list of step outputs per dataloader, each keyed with its own dataloader index
            assert len(outputs) == 2
            assert all(f"{mode}_loss_0" in x for x in outputs[0])
            assert all(f"{mode}_loss_1" in x for x in outputs[1])

        def validation_epoch_end(self, outputs):
            self.assert_epoch_end_outputs(outputs, mode="val")

        def test_epoch_end(self, outputs):
            self.assert_epoch_end_outputs(outputs, mode="test")

        def train_dataloader(self):
            return {"a": DataLoader(RandomDataset(32, 64)), "b": DataLoader(RandomDataset(32, 64))}

        def val_dataloader(self):
            return [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))]

        def test_dataloader(self):
            return [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))]

        def predict_dataloader(self):
            return [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))]

    model = CustomBoringModel()

    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=5, multiple_trainloader_mode=multiple_trainloader_mode)

    trainer.fit(model)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    trainer.test(model)

    # one list of predictions per dataloader, five batches each
    preds = trainer.predict(model)
    assert len(preds) == 2
    assert all(len(x) == 5 for x in preds)
|
|
|
|
|
|
def test_request_dataloader(tmpdir):
    """This test asserts dataloader can be wrapped.

    The model returns its train and val dataloaders inside a thin iterator wrapper and checks from the
    batch-start hooks that the trainer still exposes the wrapper objects.
    """

    class DataLoaderWrapper:
        """Minimal iterable/iterator facade around a dataloader."""

        def __init__(self, loader):
            self.loader = loader
            self._iter = iter(self.loader)

        def __iter__(self):
            # start a fresh pass over the wrapped loader on every iteration request
            self._iter = iter(self.loader)
            return self._iter

        def __next__(self):
            return next(self._iter)

    class TestModel(BoringModel):
        def __init__(self):
            super().__init__()
            # flipped by the batch-start hooks below to prove that they ran
            self.on_train_batch_start_called = False
            self.on_val_batch_start_called = False

        def train_dataloader(self):
            return DataLoaderWrapper(super().train_dataloader())

        def val_dataloader(self):
            return DataLoaderWrapper(super().val_dataloader())

        def on_train_batch_start(self, batch, batch_idx: int) -> None:
            assert isinstance(self.trainer.train_dataloader.loaders, DataLoaderWrapper)
            self.on_train_batch_start_called = True

        def on_validation_batch_start(self, batch, batch_idx: int, dataloader_idx: int) -> None:
            assert isinstance(self.trainer.val_dataloaders[0], DataLoaderWrapper)
            self.on_val_batch_start_called = True

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
    )
    model = TestModel()
    trainer.fit(model)
    trainer.test(model)

    assert model.on_train_batch_start_called
    assert model.on_val_batch_start_called
|
|
|
|
|
|
@pytest.mark.parametrize("num_loaders", [1, 2])
def test_multiple_dataloaders_with_random_sampler_overfit_batches(num_loaders, tmpdir):
    """Check that, with ``overfit_batches`` set, every train loader ends up with a ``SequentialSampler`` even
    though it was created with a ``RandomSampler``."""

    class TestModel(BoringModel):
        # no validation for this test
        validation_step = None

        def _create_dataloader(self):
            dataset = RandomDataset(32, 64)
            return DataLoader(dataset, sampler=RandomSampler(dataset))

        def train_dataloader(self):
            return [self._create_dataloader() for _ in range(num_loaders)]

        def training_step(self, batch, batch_idx):
            attached = self.trainer.train_dataloader.loaders
            for loader_index in range(num_loaders):
                sampler = attached[loader_index].loader.sampler
                assert isinstance(sampler, SequentialSampler)
            # the batch holds one entry per loader; train on the first one
            return super().training_step(batch[0], batch_idx)

    trainer = Trainer(default_root_dir=tmpdir, overfit_batches=1.0, max_epochs=1)
    trainer.fit(TestModel())
|