lightning/tests/tests_pytorch/trainer/test_dataloaders.py

1379 lines
53 KiB
Python
Raw Normal View History

2020-10-13 11:18:07 +00:00
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest.mock import call, Mock, patch
2020-04-20 08:04:37 +00:00
import numpy
import pytest
import torch
from torch.utils.data import RandomSampler
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset, IterableDataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import SequentialSampler
from lightning_lite.utilities.data import _auto_add_worker_init_fn, has_iterable_dataset
from pytorch_lightning import Callback, seed_everything, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.demos.boring_classes import (
BoringModel,
RandomDataset,
RandomIterableDataset,
RandomIterableDatasetWithLen,
)
from pytorch_lightning.loggers import CSVLogger
2021-12-17 18:10:52 +00:00
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.data import has_len_all_ranks
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.dataloaders import CustomInfDataloader, CustomNotImplementedErrorDataloader
from tests_pytorch.helpers.runif import RunIf
2021-12-17 18:10:52 +00:00
class MultiValDataLoaderBoringModel(BoringModel):
    """``BoringModel`` variant that exposes two validation dataloaders."""

    def val_dataloader(self):
        """Return two loaders over ``RandomDataset(32, 64)``; the second one uses ``batch_size=8``."""
        default_loader = DataLoader(RandomDataset(32, 64))
        batched_loader = DataLoader(RandomDataset(32, 64), batch_size=8)
        return [default_loader, batched_loader]

    def validation_step(self, batch, batch_idx, dataloader_idx):
        """Accept ``dataloader_idx`` but forward only ``batch`` and ``batch_idx`` to the parent step."""
        parent_step = super().validation_step
        return parent_step(batch, batch_idx)

    def validation_epoch_end(self, *args, **kwargs):
        """No-op override of the validation epoch-end hook."""
class MultiTestDataLoaderBoringModel(BoringModel):
    """``BoringModel`` variant that exposes two test dataloaders."""

    def test_dataloader(self):
        """Return two loaders over ``RandomDataset(32, 64)``; the second one uses ``batch_size=8``."""
        default_loader = DataLoader(RandomDataset(32, 64))
        batched_loader = DataLoader(RandomDataset(32, 64), batch_size=8)
        return [default_loader, batched_loader]

    def test_step(self, batch, batch_idx, dataloader_idx):
        """Accept ``dataloader_idx`` but forward only ``batch`` and ``batch_idx`` to the parent step."""
        parent_step = super().test_step
        return parent_step(batch, batch_idx)

    def test_epoch_end(self, *args, **kwargs):
        """No-op override of the test epoch-end hook."""
class MultiEvalDataLoaderModel(MultiValDataLoaderBoringModel, MultiTestDataLoaderBoringModel):
    """Model combining the multi-loader validation and multi-loader test variants."""
def test_fit_train_loader_only(tmpdir):
    """Fit with only a train dataloader passed to ``fit`` after the model's loader and eval hooks are cleared."""
    model = BoringModel()
    # Grab the loader before the hook that produces it is set to ``None`` below.
    loader = model.train_dataloader()

    cleared_hooks = (
        "train_dataloader",
        "val_dataloader",
        "test_dataloader",
        "validation_step",
        "validation_epoch_end",
        "test_step",
        "test_epoch_end",
    )
    for hook_name in cleared_hooks:
        setattr(model, hook_name, None)

    trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    trainer.fit(model, train_dataloaders=loader)
def test_fit_val_loader_only(tmpdir):
    """Fit with train and val dataloaders passed to ``fit`` after the model's loader and test hooks are cleared."""
    model = BoringModel()
    # Grab both loaders before the hooks that produce them are set to ``None`` below.
    train_loader = model.train_dataloader()
    val_loader = model.val_dataloader()

    cleared_hooks = (
        "train_dataloader",
        "val_dataloader",
        "test_dataloader",
        "test_step",
        "test_epoch_end",
    )
    for hook_name in cleared_hooks:
        setattr(model, hook_name, None)

    trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
@pytest.mark.parametrize("dataloader_options", [{"val_check_interval": 10000}])
def test_dataloader_config_errors_runtime(tmpdir, dataloader_options):
    """Fitting with the parametrized trainer options must raise a ``ValueError`` about the training batch count."""
    expected_message = "less than or equal to the number of the training batches"

    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, **dataloader_options)

    # The trainer is built without error; the failure is expected only once ``fit`` runs.
    with pytest.raises(ValueError, match=expected_message):
        trainer.fit(model)
2021-02-06 15:06:17 +00:00
@pytest.mark.parametrize(
    "dataloader_options",
    [
        # every option is tried with a negative value and with a value above 1.0
        {option_name: invalid_value}
        for option_name in (
            "limit_train_batches",
            "limit_val_batches",
            "limit_test_batches",
            "val_check_interval",
            "overfit_batches",
        )
        for invalid_value in (-0.1, 1.2)
    ],
)
def test_dataloader_config_errors_init(tmpdir, dataloader_options):
    """Constructing a ``Trainer`` with the parametrized option must raise ``MisconfigurationException``."""
    expected_message = "passed invalid value"
    with pytest.raises(MisconfigurationException, match=expected_message):
        Trainer(default_root_dir=tmpdir, max_epochs=1, **dataloader_options)
def test_multiple_val_dataloader(tmpdir):
    """Fit a model that defines two validation loaders and check both end up on the trainer."""
    trainer_kwargs = {
        "default_root_dir": tmpdir,
        "max_epochs": 1,
        "limit_val_batches": 0.3,
        "limit_train_batches": 1.0,
    }

    model = MultiValDataLoaderBoringModel()
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)

    # the run must have finished
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # both validation loaders must be present
    num_val_loaders = len(trainer.val_dataloaders)
    assert num_val_loaders == 2, "Multiple val_dataloaders not initiated properly"
@pytest.mark.parametrize("ckpt_path", [None, "best", "specific"])
def test_multiple_eval_dataloader(tmpdir, ckpt_path):
    """Run ``validate`` and ``test`` after ``fit`` and check two loaders are used for each."""
    model = MultiEvalDataLoaderModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=10, limit_train_batches=100)
    trainer.fit(model)

    # "specific" is a placeholder for the concrete best-checkpoint path produced by the fit above
    if ckpt_path == "specific":
        ckpt_path = trainer.checkpoint_callback.best_model_path

    trainer.validate(ckpt_path=ckpt_path, verbose=False)
    num_val_loaders = len(trainer.val_dataloaders)
    assert num_val_loaders == 2

    trainer.test(ckpt_path=ckpt_path, verbose=False)
    num_test_loaders = len(trainer.test_dataloaders)
    assert num_test_loaders == 2
def test_train_dataloader_passed_to_fit(tmpdir):
    """Check that a train dataloader handed directly to ``fit`` is the one the trainer uses."""
    model = BoringModel()
    loader = model.train_dataloader()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)

    # only the train loader is passed to fit
    trainer.fit(model, train_dataloaders=loader)

    assert trainer.num_training_batches == 2
    assert trainer.train_dataloader.loaders == loader
    assert trainer.state.finished, f"Training failed with {trainer.state}"
@pytest.mark.parametrize("ckpt_path", [None, "best", "specific"])
@pytest.mark.parametrize("n", (1, 2))
def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n):
    """Pass ``n`` evaluation dataloaders to ``fit``, ``test`` and ``validate`` and check the trainer sees ``n``."""
    train_loader = DataLoader(RandomDataset(32, 64))

    if n != 1:
        # several evaluation loaders need a model whose eval steps accept ``dataloader_idx``
        model = MultiEvalDataLoaderModel()
        eval_loaders = [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))]
    else:
        model = BoringModel()
        eval_loaders = DataLoader(RandomDataset(32, 64))

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)

    # train and validation loaders are both handed to fit
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=eval_loaders)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert len(trainer.val_dataloaders) == n

    # "specific" is a placeholder for the concrete best-checkpoint path produced by the fit above
    if ckpt_path == "specific":
        ckpt_path = trainer.checkpoint_callback.best_model_path

    trainer.test(dataloaders=eval_loaders, ckpt_path=ckpt_path)
    assert len(trainer.test_dataloaders) == n

    trainer.validate(dataloaders=eval_loaders, ckpt_path=ckpt_path)
    assert len(trainer.val_dataloaders) == n
class DummyModel(BoringModel):
    """``BoringModel`` that logs a value during training steps and at the end of validation epochs."""

    def training_step(self, batch, batch_idx):
        """Log the current global step under ``"loss"`` and return the parent's training-step output."""
        step = self.global_step
        self.log("loss", step)
        return super().training_step(batch, batch_idx)

    def validation_epoch_end(self, outputs):
        """Log the current epoch under ``"val_log"``; ``outputs`` is not used."""
        epoch = self.current_epoch
        self.log("val_log", epoch)
class Counter(Callback):
    """Callback that counts epoch starts and batch starts for the train, validation and test stages."""

    def __init__(self):
        super().__init__()
        # number of ``on_*_epoch_start`` calls seen per stage
        self.train_epoch_count = 0
        self.val_epoch_count = 0
        self.test_epoch_count = 0
        # number of ``on_*_batch_start`` calls seen per stage
        self.train_batches_seen = 0
        self.val_batches_seen = 0
        self.test_batches_seen = 0

    # --- train stage -------------------------------------------------------

    def on_train_epoch_start(self, trainer, pl_module):
        """Count one more training epoch."""
        self.train_epoch_count = self.train_epoch_count + 1

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        """Count one more training batch."""
        self.train_batches_seen = self.train_batches_seen + 1

    # --- validation stage --------------------------------------------------

    def on_validation_epoch_start(self, trainer, pl_module):
        """Count one more validation epoch."""
        self.val_epoch_count = self.val_epoch_count + 1

    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        """Count one more validation batch."""
        self.val_batches_seen = self.val_batches_seen + 1

    # --- test stage --------------------------------------------------------

    def on_test_epoch_start(self, trainer, pl_module):
        """Count one more test epoch."""
        self.test_epoch_count = self.test_epoch_count + 1

    def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        """Count one more test batch."""
        self.test_batches_seen = self.test_batches_seen + 1
def test_inf_dataloaders_with_limit_percent_batches(tmpdir):
    """Check iterable-dataset loaders report an infinite batch count when the limits are the fraction ``1.0``."""
    counter = Counter()
    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=0,
        max_epochs=1,
        callbacks=[counter],
        limit_train_batches=1.0,
        limit_val_batches=1.0,
        limit_test_batches=1.0,
    )
    model = DummyModel()

    batch_size = 8
    train_dl, val_dl, test_dl = (
        DataLoader(dataset=RandomIterableDataset(32, 128), batch_size=batch_size) for _ in range(3)
    )

    num_batches = 128 / batch_size
    for loader in (train_dl, val_dl, test_dl):
        # use ``len`` when it is reported as available, otherwise count batches by iterating
        if has_len_all_ranks(loader, trainer.strategy, model):
            observed = len(loader)
        else:
            observed = sum(1 for _ in loader)
        assert observed == num_batches

    trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    infinite = float("inf")
    assert trainer.num_training_batches == infinite
    assert counter.train_epoch_count == 1
    assert trainer.num_val_batches[0] == infinite
    assert counter.val_epoch_count == 1

    trainer.test(model, dataloaders=test_dl)
    assert trainer.num_test_batches[0] == infinite
    assert counter.test_epoch_count == 1
@pytest.mark.parametrize(
    ["dataset", "limit_train_batches"],
    [
        (RandomDataset(32, 128), 10),
        (RandomIterableDataset(32, 128), 10),
        (RandomIterableDatasetWithLen(32, 128), 10),
    ],
)
def test_dataloaders_with_limit_train_batches(tmpdir, dataset, limit_train_batches):
    """Check an integer ``limit_train_batches`` caps the training batches per epoch for each dataset kind."""
    counter = Counter()
    num_epochs = 2
    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=0,
        max_epochs=num_epochs,
        callbacks=[counter],
        limit_train_batches=limit_train_batches,
    )
    model = DummyModel()

    # the same dataset backs both the train and the validation loader
    loader_kwargs = {"dataset": dataset, "batch_size": 8}
    train_loader = DataLoader(**loader_kwargs)
    val_loader = DataLoader(**loader_kwargs)

    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert trainer.num_training_batches == limit_train_batches
    assert counter.train_epoch_count == num_epochs
    assert counter.train_batches_seen == limit_train_batches * num_epochs
@pytest.mark.parametrize(
    "dataset",
    [
        RandomDataset(32, 128),
        RandomIterableDataset(32, 128),
        RandomIterableDatasetWithLen(32, 128),
    ],
)
def test_dataloaders_with_limit_val_batches(tmpdir, dataset):
    """Check an integer ``limit_val_batches`` caps the validation batches per epoch for each dataset kind."""
    counter = Counter()
    num_epochs = 2
    val_limit = 10
    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=0,
        max_epochs=num_epochs,
        callbacks=[counter],
        limit_val_batches=val_limit,
        enable_checkpointing=False,
    )
    model = DummyModel()

    # the same dataset backs both the train and the validation loader
    loader_kwargs = {"dataset": dataset, "batch_size": 8}
    train_loader = DataLoader(**loader_kwargs)
    val_loader = DataLoader(**loader_kwargs)

    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert trainer.num_val_batches[0] == val_limit
    assert counter.val_epoch_count == num_epochs
    assert counter.val_batches_seen == val_limit * num_epochs
# NOTE(review): skipped without a stated reason — confirm why before re-enabling.
@pytest.mark.skip()
@pytest.mark.parametrize(
    "dataset",
    [
        RandomDataset(32, 128),
        RandomIterableDataset(32, 128),
        RandomIterableDatasetWithLen(32, 128),
    ],
)
def test_datasets_dataloaders_with_limit_num_batches(tmpdir, dataset):
    """Check one integer limit applied to train, val and test batches is honoured for each dataset kind."""
    counter = Counter()
    num_epochs = 2
    batch_limit = 10
    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=0,
        max_epochs=num_epochs,
        callbacks=[counter],
        limit_train_batches=batch_limit,
        limit_val_batches=batch_limit,
        limit_test_batches=batch_limit,
    )
    model = DummyModel()

    # the same dataset backs the train, validation and test loaders
    loader_kwargs = {"dataset": dataset, "batch_size": 8}
    train_loader = DataLoader(**loader_kwargs)
    val_loader = DataLoader(**loader_kwargs)
    test_loader = DataLoader(**loader_kwargs)

    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert trainer.num_training_batches == batch_limit
    assert trainer.num_val_batches[0] == batch_limit

    expected_batches_seen = batch_limit * num_epochs
    assert counter.train_epoch_count == num_epochs
    assert counter.train_batches_seen == expected_batches_seen
    assert counter.val_epoch_count == num_epochs
    assert counter.val_batches_seen == expected_batches_seen

    trainer.test(model, dataloaders=test_loader)
    assert trainer.num_test_batches[0] == batch_limit
    assert counter.test_epoch_count == 1
# NOTE(review): skipped without a stated reason — confirm why before re-enabling.
@pytest.mark.skip()
@pytest.mark.parametrize(
    ["limit_train_batches", "limit_val_batches", "limit_test_batches"],
    [(1.0, 1.0, 1.0), (0.2, 0.4, 0.4)],
)
def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
    """Check the trainer's batch counts equal each loader's length scaled by the fractional limit."""

    def scaled_length(loader, fraction):
        # expected batch count: loader length times the fraction, truncated to an int
        return int(len(loader) * fraction)

    model = MultiEvalDataLoaderModel()

    # the train loader plus the model's multiple val and test loaders, each limited by a fraction
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        limit_test_batches=limit_test_batches,
    )

    trainer.fit(model)
    train_expected = scaled_length(trainer.train_dataloader, limit_train_batches)
    val_expected = [scaled_length(loader, limit_val_batches) for loader in trainer.val_dataloaders]
    assert trainer.num_training_batches == train_expected
    assert trainer.num_val_batches == val_expected

    trainer.test(model)
    test_expected = [scaled_length(loader, limit_test_batches) for loader in trainer.test_dataloaders]
    assert trainer.num_test_batches == test_expected
@pytest.mark.parametrize(["limit_train_batches", "limit_val_batches", "limit_test_batches"], [(1, 2, 3), (1, 2, 1e50)])
def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
    """Check batch counts and evaluation-step call counts when the limits are given as numbers."""
    model = MultiEvalDataLoaderModel()

    # the train loader plus the model's multiple val and test loaders, each limited by a number of batches
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        limit_test_batches=limit_test_batches,
        num_sanity_val_steps=0,
    )

    # wrap the validation loop's evaluation step so its calls are counted while it still runs for real
    val_epoch_loop = trainer.fit_loop.epoch_loop.val_loop.epoch_loop
    with patch.object(val_epoch_loop, "_evaluation_step", wraps=val_epoch_loop._evaluation_step) as val_step_spy:
        trainer.fit(model)
        num_val_loaders = len(trainer.val_dataloaders)
        assert trainer.num_training_batches == limit_train_batches
        assert trainer.num_val_batches == [limit_val_batches] * num_val_loaders
        assert val_step_spy.call_count == limit_val_batches * num_val_loaders

    # same wrapping for the test loop
    test_epoch_loop = trainer.test_loop.epoch_loop
    with patch.object(test_epoch_loop, "_evaluation_step", wraps=test_epoch_loop._evaluation_step) as test_step_spy:
        trainer.test(model)
        loader_lengths = [len(loader) for loader in model.test_dataloader()]
        if limit_test_batches > 1e10:
            # a limit larger than the number of test batches falls back to the loaders' own lengths
            assert trainer.num_test_batches == loader_lengths
            assert test_step_spy.call_count == sum(loader_lengths)
        else:
            num_test_loaders = len(trainer.test_dataloaders)
            assert trainer.num_test_batches == [limit_test_batches] * num_test_loaders
            assert test_step_spy.call_count == limit_test_batches * num_test_loaders
# Blame annotation (2020-07-07 16:24:56 +00:00): Fix ddp tests + .test() (#2512)
#   * added base tests for tpu (bullet repeated many times in the original commit message)
#   * fix deprecation warnings
#   * Update pytorch_lightning/trainer/trainer.py
#   Co-authored-by: Jirka <jirka@pytorchlightning.ai>
#   Co-authored-by: Jeremy Jordan <13970565+jeremyjordan@users.noreply.github.com>
@pytest.mark.parametrize("fast_dev_run", [True, 1, 3, -1])
def test_dataloaders_with_fast_dev_run(tmpdir, fast_dev_run):
    """Check the trainer's limits and batch counts follow ``fast_dev_run``, and that ``-1`` is rejected."""
    model = MultiEvalDataLoaderModel()
    trainer_options = {"default_root_dir": tmpdir, "max_epochs": 2, "fast_dev_run": fast_dev_run}

    if fast_dev_run == -1:
        # a negative value is rejected when the trainer is constructed
        with pytest.raises(MisconfigurationException, match="should be >= 0"):
            Trainer(**trainer_options)
        return

    trainer = Trainer(**trainer_options)

    # a value of 1 is expected to be stored on the trainer as ``True``
    expected_flag = True if fast_dev_run == 1 else fast_dev_run
    assert trainer.fast_dev_run is expected_flag

    # ``True`` stands for a single batch
    expected_batches = 1 if expected_flag is True else expected_flag

    assert trainer.limit_train_batches == expected_batches
    assert trainer.limit_val_batches == expected_batches
    assert trainer.limit_test_batches == expected_batches
    assert trainer.num_sanity_val_steps == 0
    # ``max_epochs=2`` was requested above, yet one epoch is expected here
    assert trainer.max_epochs == 1

    trainer.fit(model)
    assert trainer.enable_validation
    assert trainer.num_training_batches == expected_batches
    assert trainer.num_val_batches == [expected_batches] * len(trainer.val_dataloaders)

    trainer.test(model)
    assert trainer.num_test_batches == [expected_batches] * len(trainer.test_dataloaders)
@pytest.mark.parametrize("ckpt_path", [None, "best", "specific"])
def test_mixing_of_dataloader_options(tmpdir, ckpt_path):
    """Verify an evaluation dataloader can be passed to ``fit`` and ``test`` without passing a train loader.

    Only ``val_dataloaders`` is given to ``fit``; the same loader is then given to ``test``. In both cases
    the trainer must hold exactly that one loader.

    Args:
        tmpdir: pytest temporary directory used as the trainer's root dir.
        ckpt_path: ``None``, ``"best"``, or the placeholder ``"specific"``, which is replaced by the
            best-checkpoint path produced by the fit.
    """
    model = BoringModel()
    eval_dataloader = DataLoader(RandomDataset(32, 64))

    trainer_options = dict(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)

    # fit model, passing only the validation loader explicitly
    trainer = Trainer(**trainer_options)
    trainer.fit(model, val_dataloaders=eval_dataloader)
    assert len(trainer.val_dataloaders) == 1
    assert trainer.val_dataloaders[0] == eval_dataloader

    # "specific" can only be resolved once the fit above has produced a checkpoint
    if ckpt_path == "specific":
        ckpt_path = trainer.checkpoint_callback.best_model_path

    trainer.test(dataloaders=eval_dataloader, ckpt_path=ckpt_path)
    # the length check was asserted twice in the original; once is sufficient
    assert len(trainer.test_dataloaders) == 1
    assert trainer.test_dataloaders[0] == eval_dataloader
2022-02-28 18:31:18 +00:00
def test_warning_on_zero_len_dataloader(tmpdir):
    """Test that a warning is raised if a zero-length dataloader is defined.

    Both the train and the validation loader wrap an empty ``RandomDataset(32, 0)``. Resetting each loader
    on the trainer must emit a ``UserWarning`` and leave the corresponding batch count at zero.

    Args:
        tmpdir: pytest temporary directory used as the trainer's root dir.
    """
    model = BoringModel()
    # Use the temporary directory as root dir, like the other tests in this module, instead of
    # letting the trainer fall back to the current working directory.
    trainer = Trainer(default_root_dir=tmpdir)

    train_dataloader = DataLoader(RandomDataset(32, 0))
    val_dataloader = DataLoader(RandomDataset(32, 0))
    trainer._data_connector.attach_data(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

    # the train-side warning names `CombinedLoader`, the validation-side one names `DataLoader`
    with pytest.warns(UserWarning, match="Total length of `CombinedLoader` across ranks is zero"):
        trainer.reset_train_dataloader(model)
    assert trainer.num_training_batches == 0

    with pytest.warns(UserWarning, match="Total length of `DataLoader` across ranks is zero"):
        trainer.reset_val_dataloader(model)
    assert trainer.num_val_batches == [0]
@RunIf(skip_windows=True)
@pytest.mark.parametrize("ckpt_path", (None, "best", "specific"))
@pytest.mark.parametrize("stage", ("train", "test", "val"))
@patch("pytorch_lightning.trainer.connectors.data_connector.multiprocessing.cpu_count", return_value=4)
def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage):
    """Check that a ``UserWarning`` is emitted when a dataloader has only a few workers.

    ``cpu_count`` is patched to return 4 while every loader is set to ``num_workers = 0``.
    """
    model = BoringModel()
    train_loader = model.train_dataloader()
    val_loader = model.val_dataloader()
    train_loader.num_workers = val_loader.num_workers = 0
    fit_kwargs = dict(train_dataloaders=train_loader, val_dataloaders=val_loader)

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)

    expected_warning = (
        f'The dataloader, {stage}_dataloader{" 0" if stage != "train" else ""}, does not have many workers'
    )
    with pytest.warns(UserWarning, match=expected_warning):
        if stage != "test":
            trainer.fit(model, **fit_kwargs)
        else:
            # a checkpoint-based test run needs a preceding fit
            if ckpt_path in ("specific", "best"):
                trainer.fit(model, **fit_kwargs)
            if ckpt_path == "specific":
                ckpt_path = trainer.checkpoint_callback.best_model_path
            trainer.test(model, dataloaders=train_loader, ckpt_path=ckpt_path)
@RunIf(skip_windows=True)
@pytest.mark.parametrize("ckpt_path", (None, "best", "specific"))
@pytest.mark.parametrize("stage", ("train", "test", "val"))
@patch("pytorch_lightning.trainer.connectors.data_connector.multiprocessing.cpu_count", return_value=4)
def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):
    """Check that the few-workers ``UserWarning`` is also emitted for collections of dataloaders.

    ``cpu_count`` is patched to return 4 while every loader is set to ``num_workers = 0``.
    """

    class CustomModel(MultiEvalDataLoaderModel):
        def training_step(self, batch, batch_idx):
            # only the first loader of the "a_b" group feeds the parent's training step
            first_ab_batch = batch["a_b"][0]
            return super().training_step(first_ab_batch, batch_idx)

    model = CustomModel()

    val_loader = DataLoader(RandomDataset(32, 64))
    val_loader.num_workers = 0
    train_loader = DataLoader(RandomDataset(32, 64))
    train_loader.num_workers = 0

    fit_kwargs = dict(
        train_dataloaders={"a_b": [train_loader] * 2, "c_d_e": [train_loader] * 3},
        val_dataloaders=[val_loader] * 2,
    )
    test_loaders = [train_loader] * 2

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)

    expected_warning = (
        f'The dataloader, {stage}_dataloader{" 0" if stage != "train" else ""}, does not have many workers'
    )
    with pytest.warns(UserWarning, match=expected_warning):
        if stage != "test":
            trainer.fit(model, **fit_kwargs)
        else:
            # a checkpoint-based test run needs a preceding fit
            if ckpt_path in ("specific", "best"):
                trainer.fit(model, **fit_kwargs)
            if ckpt_path == "specific":
                ckpt_path = trainer.checkpoint_callback.best_model_path
            trainer.test(model, dataloaders=test_loaders, ckpt_path=ckpt_path)
Add Support for multiple train loaders (#1959) * add support for wrong dtype in apply_func * apply loader resetting to possible collection of loaders * add combined loader iter class * integrate combined loader iter to training loop * fix imports * fix imports * finish supporters * add tests for supporters * add test for model with multiple loaders * fix trainer integration * fix instance check * Train loaders (#4032) * patch for issues discussed in #1959, encapsulating underlying datastructures returned from train_dataloader * update data_loading.py to it uses patch discussed in #1959 * rename class * Separate CombinedLoaderIterator into two classes, and update related tests. (#4606) * Fix the bugs after rebasing. * Add custom get_len for apply_to_collection * Refactor MultiIterator to be as CombinedLoaderIterator * To get the right num_training_batches. Call the wrapper for multi trainloader in data_loading.py, instead of training_loop.py * Reload _loader_iters when calling __iter__ * Don't transform DataLoader to CombinedLoaderIterator when it's along * Updates test_fit_multiple_train_loaders for testing num_training_batches * Seperate CombinedLoaderIterator into CombinedLoaderIterator and CombinedDataLoader. Add CombinedDataset for unified DataLoader format. * Initialize CombinedDataLoader before calculating num_training_batches. Also updating self._worker_check for multiple loaders * Update tests for supporters * Update tests for multiple trainloaders. Add tests about few_workers for multiple loaders. 
* Fix pep8 issues * Add tests for train_loader_patch.py * Add descriptions to multiple_trainloader_mode * Remove unused variables * Add docstrings and typing * Add more tests for better converage * Remove unused commented codes * Add sampler property * Remove extract_dataset * Update typing * pep8 * Update train_loader_patch.py * Apply suggestions from code review Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/trainer/supporters.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * reviewer comments * fix stupid import * add docs * add back line separator * fix line sep * pep8 * Apply suggestions from code review * fix * fix * Apply suggestions from code review Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * flake8 Co-authored-by: Justus Schock <justusschock@justuss-mbp.fritz.box> Co-authored-by: Christofer Fransson <christofer_fransson@yahoo.com> Co-authored-by: YI-LIN SUNG <r06942076@ntu.edu.tw> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
2021-01-04 19:57:53 +00:00
class NumpyRandomDataset(Dataset):
    """Fixed-size dataset whose samples come from numpy's random generator instead of torch's."""

    # number of samples reported by ``__len__``
    size = 16

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        # the index is ignored: every access draws 3 fresh integers from [0, 100)
        return numpy.random.randint(low=0, high=100, size=3)
def _user_worker_init_fn(_):
    """Do nothing; serves as a stand-in for a user-supplied ``worker_init_fn``."""
    return None
def test_auto_add_worker_init_fn():
    """Check when ``_auto_add_worker_init_fn`` installs a ``worker_init_fn`` on a dataloader.

    The attribute must stay ``None`` before ``seed_everything`` is called and after calling it with
    ``workers=False``, a user-provided function must be left untouched, and after
    ``seed_everything(..., workers=True)`` a function must be installed.
    """
    loader = DataLoader(Mock())

    # without pl.seed_everything()
    _auto_add_worker_init_fn(loader, 0)
    assert loader.worker_init_fn is None

    # with forcefully avoiding it
    seed_everything(0, workers=False)
    _auto_add_worker_init_fn(loader, 0)
    assert loader.worker_init_fn is None

    # when user already has a worker_init_fn
    loader.worker_init_fn = _user_worker_init_fn
    _auto_add_worker_init_fn(loader, 0)
    assert loader.worker_init_fn is _user_worker_init_fn
    # clear the user function again so the next step starts from a clean loader
    loader.worker_init_fn = None

    # main use case
    seed_everything(0, workers=True)
    _auto_add_worker_init_fn(loader, 0)
    assert loader.worker_init_fn is not None
class MultiProcessModel(BoringModel):
    """Model that records every training batch and checks, at epoch end, that all gathered samples are distinct."""

    def __init__(self):
        super().__init__()
        # every batch passed to ``training_step`` on this process
        self.batches_seen = []

    def training_step(self, batch, batch_idx):
        # only record the batch; nothing is returned
        self.batches_seen.append(batch)

    def training_epoch_end(self, outputs):
        expected_world_size = 2
        expected_unique = NumpyRandomDataset.size

        local_batches = torch.cat(self.batches_seen)
        gathered = self.all_gather(local_batches)
        # the gathered tensor is expected to carry one leading entry per process
        assert gathered.shape[0] == expected_world_size

        # flatten to one row per sample (three values each) and count the distinct rows
        rows = gathered.view(-1, 3)
        assert len(torch.unique(rows, dim=0)) == expected_unique
@RunIf(min_cuda_gpus=2)
def test_auto_add_worker_init_fn_distributed(tmpdir, monkeypatch):
    """Fit ``MultiProcessModel`` on two GPUs (``ddp_spawn``) with a multi-worker, numpy-backed dataloader.

    The actual checks live in ``MultiProcessModel.training_epoch_end``; this function only wires up the dataloader,
    the global seed (with ``workers=True``) and the trainer.
    """
    loader = DataLoader(NumpyRandomDataset(), batch_size=2, num_workers=2)
    seed_everything(0, workers=True)

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        accelerator="gpu",
        devices=2,
        strategy="ddp_spawn",
    )

    model = MultiProcessModel()
    # replace the validation hook with None so that only the explicitly passed train loader is in play
    model.val_dataloader = None
    trainer.fit(model, train_dataloaders=loader)
def test_warning_with_small_dataloader_and_logging_interval(tmpdir):
    """Test that a warning is shown if the dataloader length is too short for the chosen logging interval.

    Two configurations are exercised: a 10-batch dataloader with ``log_every_n_steps=11``, and the same dataloader
    limited to a single batch with ``log_every_n_steps=2``.
    """
    model = BoringModel()
    dataloader = DataLoader(RandomDataset(32, length=10))
    model.train_dataloader = lambda: dataloader

    with pytest.warns(UserWarning, match=r"The number of training batches \(10\) is smaller than the logging interval"):
        trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, log_every_n_steps=11, logger=CSVLogger(tmpdir))
        trainer.fit(model)

    with pytest.warns(UserWarning, match=r"The number of training batches \(1\) is smaller than the logging interval"):
        # log into ``tmpdir`` like the first trainer does, so the test never leaves files in the working directory
        trainer = Trainer(
            default_root_dir=tmpdir, max_epochs=1, log_every_n_steps=2, limit_train_batches=1, logger=CSVLogger(tmpdir)
        )
        trainer.fit(model)
def test_warning_with_iterable_dataset_and_len(tmpdir):
    """Check the ``IterableDataset``-with-``__len__`` warning for every trainer entry point.

    A loader over an iterable dataset that defines ``__len__`` must trigger the warning in ``validate``, ``fit``,
    ``test`` and ``predict``. A loader over the same data without ``__len__`` is simply run through all four.
    """
    model = BoringModel()
    source_dataset = model.train_dataloader().dataset
    len_warning = "Your `IterableDataset` has `__len__` defined."

    class IterableWithoutLen(IterableDataset):
        def __iter__(self):
            return iter(source_dataset)

    class IterableWithLen(IterableWithoutLen):
        def __len__(self):
            return len(source_dataset)

    # --- iterable dataset that also defines __len__ ---
    trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
    sized_loader = DataLoader(IterableWithLen(), batch_size=16)
    assert has_len_all_ranks(sized_loader, trainer.strategy, model)
    assert has_iterable_dataset(sized_loader)

    with pytest.warns(UserWarning, match=len_warning):
        trainer.validate(model, dataloaders=[sized_loader])
    with pytest.warns(UserWarning, match=len_warning):
        trainer.fit(model, train_dataloaders=sized_loader, val_dataloaders=[sized_loader])
    with pytest.warns(UserWarning, match=len_warning):
        trainer.test(model, dataloaders=[sized_loader])
    with pytest.warns(UserWarning, match=len_warning):
        trainer.predict(model, dataloaders=[sized_loader])

    # --- iterable dataset without __len__ ---
    trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
    unsized_loader = DataLoader(IterableWithoutLen(), batch_size=16)
    assert not has_len_all_ranks(unsized_loader, trainer.strategy, model)
    assert has_iterable_dataset(unsized_loader)

    trainer.validate(model, dataloaders=unsized_loader)
    trainer.fit(model, train_dataloaders=unsized_loader, val_dataloaders=[unsized_loader])
    trainer.test(model, dataloaders=unsized_loader)
    trainer.predict(model, dataloaders=unsized_loader)
@pytest.mark.parametrize("yield_at_all", (False, True))
def test_iterable_dataset_stop_iteration_at_epoch_beginning(yield_at_all):
    """Fit for two epochs on an iterable dataset that is empty in at least one epoch.

    With ``yield_at_all=True`` the generator produces three samples in epoch 0 only (two batches at
    ``batch_size=2``); with ``False`` it never produces anything. Either way both epochs must be counted.
    """

    class GeneratorBackedDataset(IterableDataset):
        def __init__(self, gen):
            self.gen = gen

        def __iter__(self):
            return iter(self.gen())

    class FirstEpochOnlyModel(BoringModel):
        def gen(self):
            # data is only produced during epoch 0 (and only if requested at all)
            if yield_at_all and self.current_epoch == 0:
                for _ in range(3):
                    yield torch.rand(32)

    model = FirstEpochOnlyModel()
    loader = DataLoader(GeneratorBackedDataset(model.gen), batch_size=2)
    trainer = Trainer(default_root_dir=os.getcwd(), max_epochs=2, enable_model_summary=False)
    trainer.fit(model, train_dataloaders=loader)

    expected_steps = 2 if yield_at_all else 0
    assert trainer.global_step == expected_steps
    # the fit loop still advances when the generator yields nothing, so the epoch counter reaches max_epochs
    assert trainer.current_epoch == 2
class DistribSamplerCallback(Callback):
    """Callback asserting that train/val/test dataloaders use a ``DistributedSampler`` with the expected settings.

    ``expected_seeds`` is indexed as ``(train, val, test)``. The train sampler must shuffle; the first validation
    and first test samplers must not.
    """

    def __init__(self, expected_seeds=(0, 0, 0)):
        self.expected_seed = expected_seeds

    def on_train_start(self, trainer, pl_module):
        sampler = trainer.train_dataloader.sampler
        assert isinstance(sampler, DistributedSampler)
        assert sampler.shuffle
        assert sampler.seed == self.expected_seed[0]

    def on_validation_start(self, trainer, pl_module):
        # only the first validation dataloader is inspected
        sampler = trainer.val_dataloaders[0].sampler
        assert isinstance(sampler, DistributedSampler)
        assert not sampler.shuffle
        assert sampler.seed == self.expected_seed[1]

    def on_test_start(self, trainer, pl_module):
        # only the first test dataloader is inspected
        sampler = trainer.test_dataloaders[0].sampler
        assert isinstance(sampler, DistributedSampler)
        assert not sampler.shuffle
        assert sampler.seed == self.expected_seed[2]
@RunIf(min_cuda_gpus=2, skip_windows=True)
def test_dataloader_distributed_sampler(tmpdir):
    """Run fit and test with ``ddp_spawn`` on two GPUs and verify the ``DistributedSampler`` and its arguments.

    The global seed is 123 and ``DistribSamplerCallback`` expects that same seed on the train, val and test samplers.
    """
    seed_everything(123)

    model = BoringModel()
    sampler_checker = DistribSamplerCallback(expected_seeds=(123, 123, 123))
    trainer = Trainer(
        default_root_dir=tmpdir,
        accelerator="gpu",
        devices=[0, 1],
        num_nodes=1,
        strategy="ddp_spawn",
        max_steps=1,
        callbacks=[sampler_checker],
    )

    trainer.fit(model)
    trainer.test(model)
2021-12-17 18:10:52 +00:00
class ModelWithDataLoaderDistributedSampler(BoringModel):
    """``BoringModel`` whose train dataloader already carries its own ``DistributedSampler`` (seed 11, shuffling)."""

    def train_dataloader(self):
        # reuse the parent's dataset, but attach a pre-built distributed sampler to a new loader
        dataset = super().train_dataloader().dataset
        own_sampler = DistributedSampler(dataset, shuffle=True, seed=11)
        return DataLoader(dataset, batch_size=32, drop_last=False, sampler=own_sampler, shuffle=False)
@RunIf(min_cuda_gpus=2, skip_windows=True)
def test_dataloader_distributed_sampler_already_attached(tmpdir):
    """Verify the ``DistributedSampler`` settings when the train dataloader already has one attached.

    The model's own train sampler uses seed 11, which the callback expects to find on the train sampler, while the
    validation sampler is expected to carry the global seed 123. Only ``fit`` is run.
    """
    seed_everything(123)

    model = ModelWithDataLoaderDistributedSampler()
    sampler_checker = DistribSamplerCallback(expected_seeds=(11, 123, 0))
    trainer = Trainer(
        default_root_dir=tmpdir,
        accelerator="gpu",
        devices=[0, 1],
        num_nodes=1,
        strategy="ddp_spawn",
        max_steps=100,
        callbacks=[sampler_checker],
        replace_sampler_ddp=True,
    )

    trainer.fit(model)
    assert trainer.state.finished, "DDP Training failed"
@pytest.mark.parametrize(
    ["multiple_trainloader_mode", "num_training_batches"],
    [("min_size", 16), ("max_size_cycle", 64)],
)
def test_fit_multiple_train_loaders(tmpdir, multiple_trainloader_mode, num_training_batches):
    """Integration test: fit with a dict of dataloader lists and check the resulting number of training batches.

    The loaders have lengths 64 and 16, so ``min_size`` is expected to give 16 batches and ``max_size_cycle`` 64.
    """

    class CustomBoringModel(BoringModel):
        def train_dataloader(self):
            # Dict[str, List[DataLoader]]: one 64-sample and one 16-sample loader, plus three 16-sample loaders
            return {
                "a_b": [DataLoader(RandomDataset(32, length)) for length in (64, 16)],
                "c_d_e": [DataLoader(RandomDataset(32, 16)) for _ in range(3)],
            }

        def training_step(self, batch, batch_idx):
            # the batch must mirror the structure returned by ``train_dataloader``
            assert len(batch) == 2
            pair = batch["a_b"]
            assert len(pair) == 2
            assert len(batch["c_d_e"]) == 3
            return super().training_step(pair[0], batch_idx)

    model = CustomBoringModel()
    trainer = Trainer(
        max_epochs=1,
        default_root_dir=tmpdir,
        multiple_trainloader_mode=multiple_trainloader_mode,
    )
    trainer.fit(model)

    # the number of training batches depends on the chosen multiple_trainloader_mode
    assert num_training_batches == trainer.num_training_batches
@pytest.mark.parametrize("check_interval", [50, 1.0])
@pytest.mark.parametrize("dataloader_wrapper", [CustomNotImplementedErrorDataloader, CustomInfDataloader])
def test_train_dataloader_not_implemented_error(tmpdir, check_interval, dataloader_wrapper):
    """Fit with train and val dataloaders wrapped by ``dataloader_wrapper`` and check that training finishes.

    Runs for both an integer and a float ``val_check_interval``.
    """

    def make_wrapped_loader():
        return dataloader_wrapper(DataLoader(RandomDataset(32, 64)))

    class CustomBoringModel(BoringModel):
        def train_dataloader(self):
            return make_wrapped_loader()

        def val_dataloader(self):
            return make_wrapped_loader()

    model = CustomBoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=5,
        max_epochs=1,
        val_check_interval=check_interval,
    )
    trainer.fit(model)

    # training must have run to completion
    assert trainer.state.finished, f"Training failed with {trainer.state}"
Bugfix/_has_len (#2307) * deal with NotImplementedError raised by torchtext * deal with NotImplementedError raised by torchtext * Added tests for dataloader which raise NotImplementedError in __len__() * Fixed some typos * enabled tests for dataloader raising NotImplementedError in __len__ and corrected match string for raised exception * deleted empty line for style compliance * refactored CustomNotImplementedErrorDataloader to derive from CustomInfDataloader * enabled reduced number of not_implemented_error dataloader test to reduce runtime for continuous integration * reduced test number of not_implemented_error dataloader test further to reduce test time * reduced test number of not_implemented_error dataloader test to one to reduce test time * disabled all not_implemented_error dataloader test to see if test pass in time * added __next__ with a reduced number (5) of elements after which CustomNotImplementedErrorDataloader stops to speedup test. * enabling all not_implemented_error dataloader test * added brief description of change and relation of torchtext * CustomNotImplementedErrorDataloader reduced number of batches served to 2. * Update CHANGELOG.md Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> * Apply suggestions from code review * Update CHANGELOG.md Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Disable parallelism in dataloader Suspect that it might cause pytest to hang more frequent * added max_steps=None to Trainer in not_implemented_error dataloader tests * rearranged not_implemented_error test in file to group them together * disabled parallel data loading Reason: testing if that stops the test framework from hanging. * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Thomas Schaaf <tschaaf@cs.cmu.edu> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
2020-06-26 13:31:08 +00:00
2021-12-17 18:10:52 +00:00
@pytest.mark.parametrize(
    "stage", [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING]
)
@pytest.mark.parametrize("dataloader_wrapper", [CustomNotImplementedErrorDataloader, CustomInfDataloader])
def test_inf_dataloader_raise_error_with_partial_batch_limits(tmpdir, stage, dataloader_wrapper):
    """A fractional ``limit_*_batches`` combined with a wrapped dataloader must raise for every stage.

    The stage's dataloader hook is replaced by one returning a wrapped loader, the matching ``limit_*_batches`` is
    set to 0.5, and the stage's trainer entry point is expected to raise ``MisconfigurationException``.
    """
    model = BoringModel()
    prefix = stage.dataloader_prefix

    def wrapped_loader():
        return dataloader_wrapper(DataLoader(RandomDataset(32, 64)))

    setattr(model, f"{prefix}_dataloader", wrapped_loader)

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, **{f"limit_{prefix}_batches": 0.5})

    # training is driven through ``fit``; the other stages map directly onto a trainer method name
    if stage == RunningStage.TRAINING:
        entry_point_name = "fit"
    else:
        entry_point_name = stage.value

    with pytest.raises(MisconfigurationException, match=r"IterableDataset`.*limit_.*_batches\)`.*`1.0` or an int"):
        getattr(trainer, entry_point_name)(model)
def test_dataloaders_load_only_once(tmpdir):
    """Fit for three epochs and check that each dataloader hook is requested at most once.

    The train and val hooks must each be called exactly once, with val first, and the test hook never.
    """
    model = BoringModel()
    tracker = Mock()

    # wrap each hook in a spy and register it on a shared parent mock so the relative call order is recorded
    for hook_name in ("train_dataloader", "val_dataloader", "test_dataloader"):
        spy = Mock(wraps=getattr(model, hook_name))
        setattr(model, hook_name, spy)
        tracker.attach_mock(spy, hook_name)

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=0.3,
        limit_val_batches=0.3,
        max_epochs=3,
    )
    trainer.fit(model)

    model.train_dataloader.assert_called_once()
    model.val_dataloader.assert_called_once()
    model.test_dataloader.assert_not_called()

    expected_calls = [call.val_dataloader(), call.train_dataloader()]
    assert tracker.mock_calls == expected_calls
def test_dataloaders_load_only_once_no_sanity_check(tmpdir):
    """Fit for three epochs with ``num_sanity_val_steps=0`` and check the order of dataloader hook calls.

    The recorded calls must be exactly one train dataloader request followed by one val dataloader request.
    """
    model = BoringModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=0.3,
        limit_val_batches=0.3,
        num_sanity_val_steps=0,
        max_epochs=3,
    )

    # wrap each hook in a spy and register it on a shared parent mock so the relative call order is recorded
    tracker = Mock()
    for hook_name in ("train_dataloader", "val_dataloader", "test_dataloader"):
        spy = Mock(wraps=getattr(model, hook_name))
        setattr(model, hook_name, spy)
        tracker.attach_mock(spy, hook_name)

    trainer.fit(model)

    # verify the sequence
    assert tracker.mock_calls == [call.train_dataloader(), call.val_dataloader()]
@pytest.mark.parametrize(
    (
        "num_sanity_val_steps, check_val_every_n_epoch, reload_dataloaders_every_n_epochs,"
        " train_reload_epochs_expect,val_reload_epochs_expect,val_step_epochs_expect"
    ),
    [
        # general case where sanity check reloads the dataloaders for validation on current_epoch=0
        (0, 1, 1, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
        (1, 1, 1, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
        # case where check_val_every_n_epoch < reload_dataloaders_every_n_epochs so expected val_reload_epoch
        # and val_step_epoch will be different
        (0, 1, 2, [0, 2, 4, 6, 8], [0, 2, 4, 6, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
        (1, 1, 2, [0, 2, 4, 6, 8], [2, 4, 6, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
        (0, 3, 4, [0, 4, 8], [2, 8], [2, 5, 8]),
        (1, 3, 4, [0, 4, 8], [2, 8], [2, 5, 8]),
        # case where check_val_every_n_epoch > reload_dataloaders_every_n_epochs so expected val_reload_epoch
        # and val_step_epoch will be same
        (0, 2, 1, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 3, 5, 7, 9], [1, 3, 5, 7, 9]),
        (1, 2, 1, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 3, 5, 7, 9], [1, 3, 5, 7, 9]),
        (0, 3, 2, [0, 2, 4, 6, 8], [2, 5, 8], [2, 5, 8]),
        (1, 3, 2, [0, 2, 4, 6, 8], [2, 5, 8], [2, 5, 8]),
        (0, 5, 2, [0, 2, 4, 6, 8], [4, 9], [4, 9]),
        (1, 5, 2, [0, 2, 4, 6, 8], [4, 9], [4, 9]),
        # case where check_val_every_n_epoch = reload_dataloaders_every_n_epochs so expected val_reload_epoch
        # and val_step_epoch will be same
        (0, 2, 2, [0, 2, 4, 6, 8], [1, 3, 5, 7, 9], [1, 3, 5, 7, 9]),
        (1, 2, 2, [0, 2, 4, 6, 8], [1, 3, 5, 7, 9], [1, 3, 5, 7, 9]),
    ],
)
def test_dataloaders_load_every_n_epochs_infrequent_val(
    tmpdir,
    num_sanity_val_steps,
    check_val_every_n_epoch,
    reload_dataloaders_every_n_epochs,
    train_reload_epochs_expect,
    val_reload_epochs_expect,
    val_step_epochs_expect,
):
    """Check in which epochs dataloaders are reloaded when validation runs only every ``check_val_every_n_epoch``.

    The model records the current epoch whenever a train/val dataloader is requested and whenever a validation
    step runs, keeping sanity-check activity separate. After a 10-epoch fit, the recordings are compared with the
    parametrized expectations.
    """
    train_reload_epochs, val_reload_epochs, val_step_epochs = [], [], []
    sanity_val_check_epochs, sanity_val_step_epochs = [], []

    class TestModel(BoringModel):
        def train_dataloader(self):
            train_reload_epochs.append(self.current_epoch)
            return super().train_dataloader()

        def val_dataloader(self):
            # requests made during the sanity check are tracked separately from regular reloads
            record = sanity_val_check_epochs if self.trainer.sanity_checking else val_reload_epochs
            record.append(self.current_epoch)
            return super().val_dataloader()

        def validation_step(self, *args, **kwargs):
            record = sanity_val_step_epochs if self.trainer.sanity_checking else val_step_epochs
            record.append(self.current_epoch)
            return super().validation_step(*args, **kwargs)

    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=10,
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=num_sanity_val_steps,
        check_val_every_n_epoch=check_val_every_n_epoch,
        reload_dataloaders_every_n_epochs=reload_dataloaders_every_n_epochs,
    )
    trainer.fit(model)

    # the sanity check, when enabled, is expected to happen once, at epoch 0
    expected_sanity_epochs = [0] if num_sanity_val_steps else []
    assert sanity_val_check_epochs == sanity_val_step_epochs == expected_sanity_epochs

    # epochs at which reloads / validation steps were observed
    assert train_reload_epochs == train_reload_epochs_expect
    assert val_reload_epochs == val_reload_epochs_expect
    assert val_step_epochs == val_step_epochs_expect
def test_dataloaders_load_every_n_epochs_frequent_val(tmpdir):
    """Check dataloader reloads when validation runs several times per epoch (via ``val_check_interval``).

    The train/val dataloader hooks and ``validation_epoch_end`` record the epoch in which they were called, and
    ``test_dataloader`` is wrapped in a ``Mock`` so its call count can be checked after ``fit`` + ``test``.
    """
    train_reload_epochs = []
    val_reload_epochs = []
    val_check_epochs = []

    class TestModel(BoringModel):
        def train_dataloader(self):
            train_reload_epochs.append(self.current_epoch)
            return super().train_dataloader()

        def val_dataloader(self):
            val_reload_epochs.append(self.current_epoch)
            return super().val_dataloader()

        def validation_epoch_end(self, outputs):
            val_check_epochs.append(self.current_epoch)
            return super().validation_epoch_end(outputs)

    model = TestModel()
    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        limit_train_batches=0.3,
        limit_val_batches=0.3,
        val_check_interval=0.3,
        reload_dataloaders_every_n_epochs=1,
        max_epochs=3,
    )
    trainer = Trainer(**trainer_kwargs)

    # spy on the test dataloader hook while still delegating to the real implementation
    model.test_dataloader = Mock(wraps=model.test_dataloader)

    trainer.fit(model)
    trainer.test(model)

    # train and val dataloaders are each requested once per epoch
    every_epoch = [0, 1, 2]
    assert train_reload_epochs == every_epoch
    assert val_reload_epochs == every_epoch
    model.test_dataloader.assert_called_once()

    # one validation run for the sanity check (epoch 0), then three validation runs in each of the three epochs
    expected_val_check_epochs = [0] + [epoch for epoch in range(3) for _ in range(3)]
    assert val_check_epochs == expected_val_check_epochs
@pytest.mark.parametrize("n", ["test", -1])
def test_dataloaders_load_every_n_epochs_exception(tmpdir, n):
    """Check that constructing a ``Trainer`` with an invalid ``reload_dataloaders_every_n_epochs`` raises.

    Both a non-integer value (``"test"``) and a negative integer (``-1``) must be rejected with a
    ``MisconfigurationException`` whose message matches ``"should be an int >"``.
    """
    with pytest.raises(MisconfigurationException, match="should be an int >"):
        Trainer(default_root_dir=tmpdir, reload_dataloaders_every_n_epochs=n)
Enables reload of dataloaders on every n epochs from every epoch (#5043) * edit arg to reload_dataloaders_every_n_epoch * init reload_dataloaders_every_n_epoch * edit logic to reload dl * update arg to test datamodule * update arg test dataloader * edit reload dl logic in eval loop * fix var name in reset_train_val_dataloaders * fix error, use current_epoch attribute * edit every_n_epoch to every_n_epochs * edit every_n_epoch to every_n_epochs * edit every_n_epoch to every_n_epochs * edit every_n_epoch to every_n_epochs * edit every_n_epoch to every_n_epochs * edit every_n_epoch to every_n_epochs * assert reload_dataloaders_every_n_epochs positive * assert reload_dataloaders_every_n_epochs positive * add trainer property should reload dl * update should reload dl in train loop * condition on should reload dl in eval loop * pep8 * fix update should reload dl in train loop * add test case * replace assertion with misconfig exception * remove unused variable * remove unnecessary checks * replace to BoringModel * remove unrequired comment * deprecate _every_epoch * add deprecated argument to trainer * test case for deprecated arg * remove unrequired assertion in train loop Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * modify misconfig exception for int Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * conv bool to int of depreciated _every_epoch Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * update description of deprecated param Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * update deprecation warning Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * modify argument to int only * fix deprecated test function name Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> * merge tests for reload dls * add propery should reload dl * removed and added to trainer property * use property in train loop * remove deprecated test * add deprecated test to new file * test case for exception * update test datamodule every_n_epochs * update trainer 
docs * update hooks with every_n_epochs * edit format if statement Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update CHANGELOG.md * Apply suggestions from code review Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> * typo in exception * pytest check only misconfig exception * remove unnecessary code in test * remove unnecessary code in deprec test * added match in test * typo in comment * revert to prev, keep only req in context manager * Apply suggestions from code review * docs * rebase * Apply suggestions from code review * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix import: model_helpers instead of model_utils * fix, add reload_dataloaders_every_n_epochs argument to data connector * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add required imports * move deprecated log * add missing import rank_zero_warn * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update varname in should_reload_dl_epoch suggestion from code review * Fix CHANGELOG. Update deprecation versions * Minor change * change property name, mark protected * update property name * update property name * Remove deprecated *_loop.py files * Rename test func * Update CHANGELOG.md * use rank_zero_deprecation * update deprecation message in trainer api docs * test deprecation with real arg name in message * fix typo in trainer docs Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Akihiro Nitta <nitta@akihironitta.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
2021-07-07 11:10:08 +00:00
def test_dataloaders_load_every_epoch_no_sanity_check(tmpdir):
    """Check the order of dataloader hook calls when reloading every epoch with the sanity check disabled.

    With ``num_sanity_val_steps=0`` and ``reload_dataloaders_every_n_epochs=1`` over three epochs, the recorded
    calls must be ``train_dataloader`` followed by ``val_dataloader`` for each epoch, and finally a single
    ``test_dataloader`` call from ``trainer.test``.
    """

    class TestModel(BoringModel):
        def validation_step(self, batch, batch_idx):
            # log the metric that the checkpoint callback below monitors
            self.log("dummy_val", 5.0)
            return super().validation_step(batch, batch_idx)

    model = TestModel()

    # This callback tests that the evaluation metrics are available by the time we run checkpointing
    checkpoint_callback = ModelCheckpoint(monitor="dummy_val", save_top_k=1)

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=0.3,
        limit_val_batches=0.3,
        num_sanity_val_steps=0,
        reload_dataloaders_every_n_epochs=1,
        max_epochs=3,
        callbacks=[checkpoint_callback],
    )

    # wrap each dataloader hook in a Mock and attach it to one parent so the relative call order is recorded
    tracker = Mock()
    for hook_name in ("train_dataloader", "val_dataloader", "test_dataloader"):
        spy = Mock(wraps=getattr(model, hook_name))
        setattr(model, hook_name, spy)
        tracker.attach_mock(spy, hook_name)

    trainer.fit(model)
    trainer.test(model)

    # one train + val reload per epoch (3 epochs), then the test dataloader
    expected_calls = [call.train_dataloader(), call.val_dataloader()] * 3 + [call.test_dataloader()]
    assert tracker.mock_calls == expected_calls
def test_dataloaders_load_only_once_passed_loaders(tmpdir):
    """Check that dataloaders passed directly to ``fit``/``test`` trigger each trainer reset method exactly once.

    The trainer's ``reset_*_dataloader`` methods are wrapped in mocks attached to a common parent, so both the
    number of calls and their relative order can be asserted.
    """
    model = BoringModel()

    # grab the loaders from the model before its hooks get removed
    train_loader = model.train_dataloader()
    val_loader = model.val_dataloader()
    test_loader = model.test_dataloader()

    # delete dataloader methods on the model
    for hook_name in ("train_dataloader", "val_dataloader", "test_dataloader"):
        setattr(model, hook_name, None)

    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=0.3, limit_val_batches=0.3, max_epochs=3)

    # first wrap every reset method, then attach the wrappers to a single tracker
    reset_names = ("reset_train_dataloader", "reset_val_dataloader", "reset_test_dataloader")
    for reset_name in reset_names:
        setattr(trainer, reset_name, Mock(wraps=getattr(trainer, reset_name)))

    tracker = Mock()
    for reset_name in reset_names:
        tracker.attach_mock(getattr(trainer, reset_name), reset_name)

    trainer.fit(model, train_loader, val_loader)
    trainer.test(model, dataloaders=test_loader)

    trainer.reset_train_dataloader.assert_called_once()
    trainer.reset_val_dataloader.assert_called_once()
    trainer.reset_test_dataloader.assert_called_once()

    expected_calls = [
        call.reset_val_dataloader(),
        call.reset_train_dataloader(model),
        call.reset_test_dataloader(),
    ]
    assert tracker.mock_calls == expected_calls
def test_dataloaders_reset_and_attach(tmpdir):
    """Check that calling ``Trainer.{fit,validate,test,predict}`` repeatedly swaps in the newly passed dataloader
    each time instead of keeping the one from the previous call."""
    # the assertions compare the datasets and not dataloaders since we patch and replace the samplers
    dataloader_0, dataloader_1, dataloader_2, dataloader_3 = (
        DataLoader(dataset=RandomDataset(32, 64)) for _ in range(4)
    )
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=1,
        limit_val_batches=1,
        limit_test_batches=1,
        limit_predict_batches=1,
    )

    # 1st fit
    trainer.fit(model, train_dataloaders=dataloader_0, val_dataloaders=dataloader_1)
    assert trainer.train_dataloader.loaders.dataset is dataloader_0.dataset
    assert trainer.val_dataloaders[0].dataset is dataloader_1.dataset

    # 2nd fit: raise the step budget by one before fitting again
    trainer.fit_loop.max_steps += 1
    trainer.fit(model, train_dataloaders=dataloader_2, val_dataloaders=dataloader_3)
    assert trainer.train_dataloader.loaders.dataset is dataloader_2.dataset
    assert trainer.val_dataloaders[0].dataset is dataloader_3.dataset

    # validate, test and predict are each run twice, first with dataloader_0 and then with dataloader_1
    entry_points = (
        (trainer.validate, "val_dataloaders"),
        (trainer.test, "test_dataloaders"),
        (trainer.predict, "predict_dataloaders"),
    )
    for run, loaders_attr in entry_points:
        for passed_loader in (dataloader_0, dataloader_1):
            run(model, dataloaders=passed_loader)
            assert getattr(trainer, loaders_attr)[0].dataset is passed_loader.dataset
@pytest.mark.parametrize("multiple_trainloader_mode", ["min_size", "max_size_cycle"])
def test_correct_dataloader_idx_in_hooks(tmpdir, multiple_trainloader_mode):
    """Check the correct dataloader_idx inside hooks.

    The model uses a dict of two train dataloaders and lists of two val/test/predict dataloaders. Every batch
    transfer hook and every step hook asserts the ``dataloader_idx`` it receives against the number of batches
    seen so far in the current stage.
    """

    class CustomBoringModel(BoringModel):
        def __init__(self):
            super().__init__()
            # number of batches seen so far while validating / testing
            self.val_call_count = 0
            self.test_call_count = 0

        def assert_dataloader_idx_hook(self, dataloader_idx):
            if self.trainer.training:
                assert dataloader_idx == 0
            elif self.trainer.validating:
                # the trainer below runs with ``fast_dev_run=5``: the first 5 counted batches must come from
                # dataloader 0, later ones from dataloader 1
                assert dataloader_idx == (0 if self.val_call_count <= 5 else 1)
            elif self.trainer.testing:
                assert dataloader_idx == (0 if self.test_call_count <= 5 else 1)

        def transfer_batch_to_device(self, batch, device, dataloader_idx):
            self.assert_dataloader_idx_hook(dataloader_idx)
            return super().transfer_batch_to_device(batch, device, dataloader_idx)

        def on_before_batch_transfer(self, batch, dataloader_idx):
            # incrementing here since this is the first hook called at each step
            if self.trainer.validating:
                self.val_call_count += 1
            elif self.trainer.testing:
                self.test_call_count += 1

            self.assert_dataloader_idx_hook(dataloader_idx)
            return super().on_before_batch_transfer(batch, dataloader_idx)

        def on_after_batch_transfer(self, batch, dataloader_idx):
            self.assert_dataloader_idx_hook(dataloader_idx)
            return super().on_after_batch_transfer(batch, dataloader_idx)

        def training_step(self, batch, batch_idx):
            # the train dataloaders are a dict; only the batch from loader "a" is used for the step
            return super().training_step(batch["a"], batch_idx)

        def validation_step(self, batch, batch_idx, dataloader_idx):
            self.assert_dataloader_idx_hook(dataloader_idx)
            out = super().validation_step(batch, batch_idx)
            # re-key the loss by dataloader so the epoch-end hook can tell the loaders apart
            loss = out.pop("x")
            out[f"val_loss_{dataloader_idx}"] = loss
            return out

        def test_step(self, batch, batch_idx, dataloader_idx):
            self.assert_dataloader_idx_hook(dataloader_idx)
            out = super().test_step(batch, batch_idx)
            loss = out.pop("y")
            out[f"test_loss_{dataloader_idx}"] = loss
            return out

        def predict_step(self, batch, batch_idx, dataloader_idx):
            # NOTE: this override was previously named ``predict``, which is not the prediction hook
            # (``predict_step`` is) and therefore was never exercised by ``trainer.predict``
            self.assert_dataloader_idx_hook(dataloader_idx)
            return super().predict_step(batch, batch_idx, dataloader_idx)

        def assert_epoch_end_outputs(self, outputs, mode):
            # one list of step outputs per dataloader, each keyed with its own dataloader index
            assert len(outputs) == 2
            assert all(f"{mode}_loss_0" in x for x in outputs[0])
            assert all(f"{mode}_loss_1" in x for x in outputs[1])

        def validation_epoch_end(self, outputs):
            self.assert_epoch_end_outputs(outputs, mode="val")

        def test_epoch_end(self, outputs):
            self.assert_epoch_end_outputs(outputs, mode="test")

        def train_dataloader(self):
            return {"a": DataLoader(RandomDataset(32, 64)), "b": DataLoader(RandomDataset(32, 64))}

        def val_dataloader(self):
            return [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))]

        def test_dataloader(self):
            return [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))]

        def predict_dataloader(self):
            return [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))]

    model = CustomBoringModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=5, multiple_trainloader_mode=multiple_trainloader_mode)

    trainer.fit(model)
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    trainer.test(model)

    # one list of predictions per predict dataloader, 5 batches each
    preds = trainer.predict(model)
    assert len(preds) == 2
    assert all(len(x) == 5 for x in preds)
def test_request_dataloader(tmpdir):
    """Check that the trainer accepts train/val dataloaders wrapped in a custom iterable object.

    The batch-start hooks assert that the trainer still holds the wrapper (not the underlying loader) and flip a
    flag so the test can verify afterwards that they actually ran.
    """

    class DataLoaderWrapper:
        """Thin iterable around a loader that restarts the underlying iterator on every ``iter()`` call."""

        def __init__(self, loader):
            self.loader = loader
            self._iter = iter(self.loader)

        def __iter__(self):
            self._iter = iter(self.loader)
            return self._iter

        def __next__(self):
            return next(self._iter)

    class TestModel(BoringModel):
        def __init__(self):
            super().__init__()
            # set to True by the corresponding ``on_*_batch_start`` hook
            self.on_train_batch_start_called = False
            self.on_val_batch_start_called = False

        def train_dataloader(self):
            return DataLoaderWrapper(super().train_dataloader())

        def val_dataloader(self):
            return DataLoaderWrapper(super().val_dataloader())

        def on_train_batch_start(self, batch, batch_idx: int) -> None:
            assert isinstance(self.trainer.train_dataloader.loaders, DataLoaderWrapper)
            self.on_train_batch_start_called = True

        def on_validation_batch_start(self, batch, batch_idx: int, dataloader_idx: int) -> None:
            assert isinstance(self.trainer.val_dataloaders[0], DataLoaderWrapper)
            self.on_val_batch_start_called = True

    trainer_kwargs = dict(
        default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, limit_test_batches=2, max_epochs=1
    )
    trainer = Trainer(**trainer_kwargs)
    model = TestModel()

    trainer.fit(model)
    trainer.test(model)

    # both hooks must have fired at least once
    assert model.on_train_batch_start_called
    assert model.on_val_batch_start_called
@pytest.mark.parametrize("num_loaders", [1, 2])
def test_multiple_dataloaders_with_random_sampler_overfit_batches(num_loaders, tmpdir):
    """Check the samplers seen in ``training_step`` when fitting with ``overfit_batches`` and several loaders.

    Every train dataloader is created with a ``RandomSampler``; inside ``training_step`` each loader held by the
    trainer is asserted to use a ``SequentialSampler`` instead.
    """

    class TestModel(BoringModel):
        # NOTE(review): presumably disables the validation loop for this model - confirm
        validation_step = None

        def _create_dataloader(self):
            dataset = RandomDataset(32, 64)
            return DataLoader(dataset, sampler=RandomSampler(dataset))

        def train_dataloader(self):
            return [self._create_dataloader() for _ in range(num_loaders)]

        def training_step(self, batch, batch_idx):
            assert all(
                isinstance(self.trainer.train_dataloader.loaders[idx].loader.sampler, SequentialSampler)
                for idx in range(num_loaders)
            )
            # the batch is a list with one entry per loader; only the first one is used for the step
            return super().training_step(batch[0], batch_idx)

    model = TestModel()
    trainer = Trainer(default_root_dir=tmpdir, overfit_batches=1.0, max_epochs=1)
    trainer.fit(model)