1162 lines
41 KiB
Python
1162 lines
41 KiB
Python
# Copyright The PyTorch Lightning team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import os
|
|
from unittest import mock
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
import torch
|
|
from torch.utils.data.dataloader import DataLoader
|
|
from torch.utils.data.dataset import IterableDataset, Subset
|
|
from torch.utils.data.distributed import DistributedSampler
|
|
from torch.utils.data.sampler import SequentialSampler
|
|
|
|
import tests.helpers.pipelines as tpipes
|
|
from pytorch_lightning import Callback, Trainer
|
|
from pytorch_lightning.trainer.states import TrainerState
|
|
from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
|
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
|
from tests.base import EvalModelTemplate
|
|
from tests.helpers.boring_model import BoringModel, RandomDataset
|
|
from tests.helpers.runif import RunIf
|
|
|
|
|
|
def test_fit_train_loader_only(tmpdir):
    """Fit with only an explicit train dataloader while every loader and eval hook is disabled."""
    model = EvalModelTemplate()
    loader = model.train_dataloader()

    # remove all dataloader hooks as well as the validation/test step hooks
    disabled_hooks = (
        'train_dataloader',
        'val_dataloader',
        'test_dataloader',
        'validation_step',
        'validation_epoch_end',
        'test_step',
        'test_epoch_end',
    )
    for hook_name in disabled_hooks:
        setattr(model, hook_name, None)

    Trainer(fast_dev_run=True, default_root_dir=tmpdir).fit(model, train_dataloader=loader)
|
|
|
|
|
|
def test_fit_val_loader_only(tmpdir):
    """Fit with explicit train and val dataloaders after the model's own loader and test hooks are removed."""
    model = EvalModelTemplate()
    loaders = dict(
        train_dataloader=model.train_dataloader(),
        val_dataloaders=model.val_dataloader(),
    )

    for hook_name in ('train_dataloader', 'val_dataloader', 'test_dataloader', 'test_step', 'test_epoch_end'):
        setattr(model, hook_name, None)

    trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    trainer.fit(model, **loaders)
|
|
|
|
|
|
@pytest.mark.parametrize("dataloader_options", [
    dict(val_check_interval=10000),
])
def test_dataloader_config_errors_runtime(tmpdir, dataloader_options):
    """Options that can only be checked against the real dataloader raise ``ValueError`` from ``fit``.

    Here that is a very large integer ``val_check_interval``; the Trainer accepts it at construction time.
    """
    trainer_kwargs = dict(default_root_dir=tmpdir, max_epochs=1, **dataloader_options)
    model = EvalModelTemplate()
    trainer = Trainer(**trainer_kwargs)

    with pytest.raises(ValueError):
        trainer.fit(model)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "dataloader_options",
    [
        {option: bad_value}
        for option in (
            'limit_train_batches',
            'limit_val_batches',
            'limit_test_batches',
            'val_check_interval',
            'overfit_batches',
        )
        for bad_value in (-0.1, 1.2)
    ],
)
def test_dataloader_config_errors_init(tmpdir, dataloader_options):
    """Negative fractions and fractions above 1.0 for batch-limit style options are rejected by the constructor."""
    with pytest.raises(MisconfigurationException, match='passed invalid value'):
        Trainer(default_root_dir=tmpdir, max_epochs=1, **dataloader_options)
|
|
|
|
|
|
def test_multiple_val_dataloader(tmpdir):
    """Verify that a model returning several val dataloaders trains and is evaluated on each of them."""
    model = EvalModelTemplate()
    model.val_dataloader = model.val_dataloader__multiple
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=1.0,
    )
    trainer.fit(model)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

    val_loaders = trainer.val_dataloaders
    assert len(val_loaders) == 2, 'Multiple val_dataloaders not initiated properly'

    # every val set must yield good predictions from the trained model
    for loader in val_loaders:
        tpipes.run_prediction_eval_model_template(trained_model=model, dataloader=loader)
|
|
|
|
|
|
@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
def test_multiple_eval_dataloader(tmpdir, ckpt_path):
    """Verify that both `validate` and `test` handle two evaluation dataloaders."""

    class MultipleTestDataloaderModel(EvalModelTemplate):

        def test_dataloader(self):
            return [self.dataloader(train=False) for _ in range(2)]

        def test_step(self, *args, **kwargs):
            return super().test_step__multiple_dataloaders(*args, **kwargs)

        def val_dataloader(self):
            # validation reuses the same pair of held-out loaders
            return self.test_dataloader()

        def validation_step(self, *args, **kwargs):
            # reuse the test logic, renaming the returned keys into the val namespace
            test_output = self.test_step(*args, **kwargs)
            return {key.replace("test_", "val_"): value for key, value in test_output.items()}

    model = MultipleTestDataloaderModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=10, limit_train_batches=100)
    trainer.fit(model)

    if ckpt_path == 'specific':
        ckpt_path = trainer.checkpoint_callback.best_model_path

    def check_predictions(loaders):
        # there must be exactly two loaders, and each must give good predictions
        assert len(loaders) == 2
        for loader in loaders:
            tpipes.run_prediction_eval_model_template(trainer.model, loader)

    trainer.validate(ckpt_path=ckpt_path, verbose=False)
    check_predictions(trainer.val_dataloaders)

    trainer.test(ckpt_path=ckpt_path, verbose=False)
    check_predictions(trainer.test_dataloaders)
|
|
|
|
|
|
def test_train_dataloader_passed_to_fit(tmpdir):
    """Verify that a train dataloader can be handed directly to `fit`."""
    model = EvalModelTemplate()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)

    # only the train loader is supplied explicitly
    trainer.fit(model, train_dataloader=model.dataloader(train=True))

    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
|
|
|
|
|
|
@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
@pytest.mark.parametrize("n", (1, 2))
def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n):
    """Verify that one or several eval dataloaders can be passed to `fit`, `test` and `validate`."""
    model = EvalModelTemplate()

    if n != 1:
        # the same loader object repeated, with the multi-dataloader step hooks installed
        dataloaders = [model.dataloader(train=False)] * 2
        model.validation_step = model.validation_step__multiple_dataloaders
        model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
        model.test_step = model.test_step__multiple_dataloaders
    else:
        dataloaders = model.dataloader(train=False)

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)
    trainer.fit(model, train_dataloader=model.dataloader(train=True), val_dataloaders=dataloaders)

    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    assert len(trainer.val_dataloaders) == n

    if ckpt_path == 'specific':
        ckpt_path = trainer.checkpoint_callback.best_model_path

    trainer.test(test_dataloaders=dataloaders, ckpt_path=ckpt_path)
    trainer.validate(val_dataloaders=dataloaders, ckpt_path=ckpt_path)

    assert len(trainer.val_dataloaders) == n
    assert len(trainer.test_dataloaders) == n
|
|
|
|
|
|
@pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [
    (0.0, 0.0, 0.0),
    (1.0, 1.0, 1.0),
])
def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
    """Fractional batch limits on infinite train/val/test loaders (e.g. IterableDataset).

    A fraction of 0.0 yields zero batches; 1.0 leaves the loop length infinite.
    """
    model = EvalModelTemplate()
    model.train_dataloader = model.train_dataloader__infinite
    model.val_dataloader = model.val_dataloader__infinite
    model.test_dataloader = model.test_dataloader__infinite

    def expected_batches(limit):
        return 0 if limit == 0.0 else float('inf')

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        limit_test_batches=limit_test_batches,
    )

    trainer.fit(model)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    assert trainer.num_training_batches == expected_batches(limit_train_batches)
    assert trainer.num_val_batches[0] == expected_batches(limit_val_batches)

    trainer.test(ckpt_path=None)
    assert trainer.num_test_batches[0] == expected_batches(limit_test_batches)
|
|
|
|
|
|
@pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [
    (0, 0, 0),
    (10, 10, 10),
])
def test_inf_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
    """Integer batch limits on infinite train/val/test loaders (e.g. IterableDataset) are used verbatim."""
    model = EvalModelTemplate()
    # swap every loader hook for its infinite variant
    for stage in ('train', 'val', 'test'):
        setattr(model, f'{stage}_dataloader', getattr(model, f'{stage}_dataloader__infinite'))

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        limit_test_batches=limit_test_batches,
    )

    trainer.fit(model)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    assert trainer.num_training_batches == limit_train_batches
    assert trainer.num_val_batches[0] == limit_val_batches

    trainer.test(ckpt_path=None)
    assert trainer.num_test_batches[0] == limit_test_batches
|
|
|
|
|
|
@pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [
    (0.0, 0.0, 0.0),
    (0, 0, 0.5),
    (1.0, 1.0, 1.0),
    (0.2, 0.4, 0.4),
])
def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
    """Fractional batch limits scale each finite loader's length (truncated to int) for train, val and test."""
    model = EvalModelTemplate()
    model.val_dataloader = model.val_dataloader__multiple_mixed_length
    model.test_dataloader = model.test_dataloader__multiple_mixed_length
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
    model.test_step = model.test_step__multiple_dataloaders
    model.test_epoch_end = model.test_epoch_end__multiple_dataloaders

    def scaled(loader, fraction):
        # expected batch count: the loader length times the fraction, truncated
        return int(len(loader) * fraction)

    # train, multiple val and multiple test loaders, each limited by a fraction
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        limit_test_batches=limit_test_batches,
    )
    trainer.fit(model)
    assert trainer.num_training_batches == scaled(trainer.train_dataloader, limit_train_batches)
    assert trainer.num_val_batches == [scaled(loader, limit_val_batches) for loader in trainer.val_dataloaders]

    trainer.test(ckpt_path=None)
    assert trainer.num_test_batches == [scaled(loader, limit_test_batches) for loader in trainer.test_dataloaders]
|
|
|
|
|
|
@pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [
    (0, 0, 0),
    (1, 2, 3),
    (1, 2, 1e50),
])
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
    """Verify num_batches for train, val & test dataloaders passed with batch limit as number.

    Besides the values the trainer reports, ``trainer.dev_debugger`` (this test sets ``PL_DEV_DEBUG``)
    is used to check how many batches were actually run. A test limit larger than a loader's length
    (the ``1e50`` case) must fall back to that loader's own length.
    """
    # any test limit above this is treated as "larger than every test loader"
    huge_limit = 1e10

    model = EvalModelTemplate()
    model.val_dataloader = model.val_dataloader__multiple_mixed_length
    model.test_dataloader = model.test_dataloader__multiple_mixed_length
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
    model.test_step = model.test_step__multiple_dataloaders
    model.test_epoch_end = model.test_epoch_end__multiple_dataloaders

    # train, multiple val and multiple test loaders, each capped by an absolute number of batches
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        limit_test_batches=limit_test_batches,
    )
    trainer.fit(model)

    # -------------------------------------------
    # MAKE SURE THE TRAINER SET THE CORRECT VALUES
    # -------------------------------------------
    assert trainer.num_training_batches == limit_train_batches
    assert trainer.num_val_batches == [limit_val_batches] * len(trainer.val_dataloaders)
    trainer.test(ckpt_path=None)

    # when the limit is greater than the number of test batches it should be the num in loaders
    test_dataloader_lengths = [len(x) for x in model.test_dataloader()]
    if limit_test_batches > huge_limit:
        assert trainer.num_test_batches == test_dataloader_lengths
    else:
        assert trainer.num_test_batches == [limit_test_batches] * len(trainer.test_dataloaders)

    # -------------------------------------------
    # make sure we actually saw the expected num of batches
    # -------------------------------------------
    num_val_dataloaders = len(model.val_dataloader())
    # reuse the lengths gathered above instead of building the test loaders a second time
    num_test_dataloaders = len(test_dataloader_lengths)
    if limit_train_batches > 0:

        # make sure val batches are as expected
        seen_val_batches = trainer.dev_debugger.num_seen_val_check_batches
        assert len(seen_val_batches) == num_val_dataloaders
        for num_batches in seen_val_batches.values():
            assert num_batches == limit_val_batches

        # make sure test batches are as expected
        seen_test_batches = trainer.dev_debugger.num_seen_test_check_batches
        assert len(seen_test_batches) == num_test_dataloaders
        for dataloader_idx, num_batches in seen_test_batches.items():
            if limit_test_batches > huge_limit:
                assert num_batches == test_dataloader_lengths[dataloader_idx]
            else:
                assert num_batches == limit_test_batches
|
|
|
|
|
|
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.parametrize('fast_dev_run', [True, 1, 3, -1, 'temp'])
def test_dataloaders_with_fast_dev_run(tmpdir, fast_dev_run):
    """
    Check the batch counts of the train, val and test dataloaders under ``fast_dev_run``.

    Invalid values (``-1`` and values that are neither bool nor int) must be rejected by the Trainer constructor.
    """
    model = EvalModelTemplate()
    model.val_dataloader = model.val_dataloader__multiple_mixed_length
    model.test_dataloader = model.test_dataloader__multiple_mixed_length
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
    model.test_step = model.test_step__multiple_dataloaders
    model.test_epoch_end = model.test_epoch_end__multiple_dataloaders

    trainer_options = dict(default_root_dir=tmpdir, max_epochs=2, fast_dev_run=fast_dev_run)

    # invalid values and the error message each one must produce
    rejected = {'temp': 'either a bool or an int', -1: 'should be >= 0'}
    if fast_dev_run in rejected:
        with pytest.raises(MisconfigurationException, match=rejected[fast_dev_run]):
            Trainer(**trainer_options)
        return

    trainer = Trainer(**trainer_options)

    # the Trainer stores the integer 1 as True; other integers are kept unchanged
    expected_flag = True if fast_dev_run == 1 else fast_dev_run
    assert trainer.fast_dev_run is expected_flag

    # number of batches every loop is capped to
    num_batches = 1 if expected_flag is True else expected_flag

    assert trainer.limit_train_batches == num_batches
    assert trainer.limit_val_batches == num_batches
    assert trainer.limit_test_batches == num_batches
    assert trainer.num_sanity_val_steps == 0
    # fast_dev_run overrides the requested max_epochs=2
    assert trainer.max_epochs == 1

    trainer.fit(model)
    assert not trainer.disable_validation
    assert trainer.num_training_batches == num_batches
    assert trainer.num_val_batches == [num_batches] * len(trainer.val_dataloaders)

    trainer.test(ckpt_path=None)
    assert trainer.num_test_batches == [num_batches] * len(trainer.test_dataloaders)

    # verify sanity check batches match as expected
    num_val_dataloaders = len(model.val_dataloader())
    assert trainer.dev_debugger.num_seen_sanity_check_batches == trainer.num_sanity_val_steps * num_val_dataloaders
|
|
|
|
|
|
@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
def test_mixing_of_dataloader_options(tmpdir, ckpt_path):
    """Verify that a val dataloader passed to `fit` can be combined with a test dataloader passed to `test`."""
    model = EvalModelTemplate()
    trainer_options = dict(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)

    # fit twice, each time with a fresh Trainer; the second Trainer is used afterwards
    for _ in range(2):
        trainer = Trainer(**trainer_options)
        trainer.fit(model, val_dataloaders=model.dataloader(train=False))
        assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

    if ckpt_path == 'specific':
        ckpt_path = trainer.checkpoint_callback.best_model_path
    trainer.test(test_dataloaders=model.dataloader(train=False), ckpt_path=ckpt_path)

    assert len(trainer.val_dataloaders) == 1, \
        f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
    assert len(trainer.test_dataloaders) == 1, \
        f'`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'
|
|
|
|
|
|
def test_train_inf_dataloader_error(tmpdir):
    """An infinite train loader (e.g. IterableDataset) combined with ``val_check_interval=0.5`` must be rejected."""
    model = EvalModelTemplate()
    model.train_dataloader = model.train_dataloader__infinite
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=0.5)

    expected_message = 'using an IterableDataset'
    with pytest.raises(MisconfigurationException, match=expected_message):
        trainer.fit(model)
|
|
|
|
|
|
def test_val_inf_dataloader_error(tmpdir):
    """Test that an infinite val data loader (e.g. IterableDataset) with a fractional `limit_val_batches` raises."""
    model = EvalModelTemplate()
    model.val_dataloader = model.val_dataloader__infinite

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.5)

    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
        trainer.fit(model)
|
|
|
|
|
|
def test_test_inf_dataloader_error(tmpdir):
    """Test that an infinite test data loader (e.g. IterableDataset) with a fractional `limit_test_batches` raises."""
    model = EvalModelTemplate()
    model.test_dataloader = model.test_dataloader__infinite

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=0.5)

    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
        trainer.test(model)
|
|
|
|
|
|
@pytest.mark.parametrize('check_interval', [50, 1.0])
def test_inf_train_dataloader(tmpdir, check_interval):
    """An infinite train loader (e.g. IterableDataset) trains to completion with an integer or 1.0 check interval."""
    model = EvalModelTemplate()
    model.train_dataloader = model.train_dataloader__infinite

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=check_interval)
    trainer.fit(model)

    # the run must have finished cleanly
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
|
|
|
|
|
|
@pytest.mark.parametrize('check_interval', [1.0])
def test_inf_val_dataloader(tmpdir, check_interval):
    """An infinite val loader (e.g. IterableDataset) does not prevent training from finishing."""
    model = EvalModelTemplate()
    model.val_dataloader = model.val_dataloader__infinite

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=check_interval)
    trainer.fit(model)

    # the run must have finished cleanly
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
|
|
|
|
|
|
def test_error_on_zero_len_dataloader(tmpdir):
    """Test that `fit` raises a ValueError when the train dataloader has zero length."""
    model = EvalModelTemplate()
    model.train_dataloader = model.train_dataloader__zero_length

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=0.1,
        limit_val_batches=0.1,
        limit_test_batches=0.1,
    )

    # only the fit call is expected to raise; building the Trainer itself must succeed
    with pytest.raises(ValueError):
        trainer.fit(model)
|
|
|
|
|
|
@RunIf(skip_windows=True)
@pytest.mark.parametrize('ckpt_path', (None, 'best', 'specific'))
@pytest.mark.parametrize('stage', ('train', 'test', 'val'))
@patch('pytorch_lightning.trainer.data_loading.multiprocessing.cpu_count', return_value=4)
def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage):
    """Test that a warning is raised if a dataloader with only a few workers is used.

    ``cpu_count`` is patched to 4 while the loaders are forced to ``num_workers=0``.
    """
    model = BoringModel()

    train_dl = model.train_dataloader()
    train_dl.num_workers = 0

    val_dl = model.val_dataloader()
    val_dl.num_workers = 0

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=0.2,
    )

    # resolve the checkpoint outside `pytest.warns` so the context only wraps the call under test
    if stage == 'test' and ckpt_path == 'specific':
        ckpt_path = trainer.checkpoint_callback.best_model_path

    with pytest.warns(
        UserWarning,
        match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers'
    ):
        if stage == 'test':
            trainer.test(model, test_dataloaders=train_dl, ckpt_path=ckpt_path)
        else:
            trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl)
|
|
|
|
|
|
@RunIf(skip_windows=True)
@pytest.mark.parametrize('ckpt_path', (None, 'best', 'specific'))
@pytest.mark.parametrize('stage', ('train', 'test', 'val'))
@patch('pytorch_lightning.trainer.data_loading.multiprocessing.cpu_count', return_value=4)
def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):
    """Test that a warning is raised if multiple dataloaders with only a few workers are used.

    Train loaders are passed as a dict, val/test loaders as lists; all use ``num_workers=0``
    while ``cpu_count`` is patched to 4.
    """
    model = EvalModelTemplate()
    model.training_step = model.training_step__multiple_dataloaders
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
    model.test_step = model.test_step__multiple_dataloaders
    model.test_epoch_end = model.test_epoch_end__multiple_dataloaders

    val_dl = model.dataloader(train=False)
    val_dl.num_workers = 0

    train_dl = model.dataloader(train=False)
    train_dl.num_workers = 0

    train_multi_dl = {'a': train_dl, 'b': train_dl}
    val_multi_dl = [val_dl, val_dl]
    test_multi_dl = [train_dl, train_dl]

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=0.2,
    )

    # resolve the checkpoint outside `pytest.warns` so the context only wraps the call under test
    if stage == 'test' and ckpt_path == 'specific':
        ckpt_path = trainer.checkpoint_callback.best_model_path

    with pytest.warns(
        UserWarning,
        match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers'
    ):
        if stage == 'test':
            trainer.test(model, test_dataloaders=test_multi_dl, ckpt_path=ckpt_path)
        else:
            trainer.fit(model, train_dataloader=train_multi_dl, val_dataloaders=val_multi_dl)
|
|
|
|
|
|
def test_warning_with_iterable_dataset_and_len(tmpdir):
    """A warning must be emitted in both `fit` and `test` when an `IterableDataset` also defines `__len__`."""
    model = EvalModelTemplate()
    source_dataset = model.train_dataloader().dataset

    class IterableWithLen(IterableDataset):

        def __iter__(self):
            return iter(source_dataset)

        def __len__(self):
            return len(source_dataset)

    loader = DataLoader(IterableWithLen(), batch_size=16)
    # sanity check: the loader is seen as both sized and iterable-style
    assert has_len(loader)
    assert has_iterable_dataset(loader)

    trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
    expected_warning = 'Your `IterableDataset` has `__len__` defined.'
    with pytest.warns(UserWarning, match=expected_warning):
        trainer.fit(model, train_dataloader=loader, val_dataloaders=[loader])
    with pytest.warns(UserWarning, match=expected_warning):
        trainer.test(model, test_dataloaders=[loader])
|
|
|
|
|
|
@RunIf(min_gpus=2)
def test_dataloader_reinit_for_subclass(tmpdir):
    """`auto_add_sampler` re-creates DataLoader subclasses (keeping custom attributes) and returns other objects as-is.

    It must also refuse to replace a sampler the user already attached.
    """

    class CustomDataLoader(torch.utils.data.DataLoader):

        def __init__(
            self,
            dataset,
            batch_size=1,
            shuffle=False,
            sampler=None,
            batch_sampler=None,
            num_workers=0,
            collate_fn=None,
            pin_memory=False,
            drop_last=False,
            timeout=0,
            worker_init_fn=None,
            dummy_kwarg=None,
            **kwargs
        ):
            super().__init__(
                dataset,
                batch_size=batch_size,
                shuffle=shuffle,
                sampler=sampler,
                batch_sampler=batch_sampler,
                num_workers=num_workers,
                collate_fn=collate_fn,
                pin_memory=pin_memory,
                drop_last=drop_last,
                timeout=timeout,
                worker_init_fn=worker_init_fn,
            )
            # extra attribute that must survive re-instantiation
            self.dummy_kwarg = dummy_kwarg

    trainer = Trainer(gpus=[0, 1], num_nodes=1, accelerator='ddp_spawn', default_root_dir=tmpdir)

    class CustomDummyObj:
        sampler = None

    # objects that are not DataLoaders must come back untouched
    reinit = trainer.auto_add_sampler(CustomDummyObj(), shuffle=True)
    assert isinstance(reinit, CustomDummyObj), "Wrongly reinstantiated data loader"

    # both an unshuffled and a shuffled subclass instance are rebuilt as the same subclass
    for shuffle_flag in (False, True):
        reinit = trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), shuffle=shuffle_flag), shuffle=True)
        assert isinstance(reinit, torch.utils.data.DataLoader)
        assert isinstance(reinit, CustomDataLoader)
        assert hasattr(reinit, 'dummy_kwarg')

    class CustomSampler(torch.utils.data.Sampler):
        pass

    # an existing user sampler must not be silently replaced
    with pytest.raises(MisconfigurationException, match='DistributedSampler'):
        trainer.auto_add_sampler(
            CustomDataLoader(list(range(1000)), sampler=CustomSampler(list(range(1000)))), shuffle=True
        )
|
|
|
|
|
|
class DistribSamplerCallback(Callback):
    """Check that the train/val/test loaders use a `DistributedSampler`, with shuffling enabled only for training."""

    @staticmethod
    def _check_sampler(sampler, shuffled):
        assert isinstance(sampler, DistributedSampler)
        if shuffled:
            assert sampler.shuffle
        else:
            assert not sampler.shuffle

    def on_train_start(self, trainer, pl_module):
        self._check_sampler(trainer.train_dataloader.sampler, shuffled=True)

    def on_validation_start(self, trainer, pl_module):
        self._check_sampler(trainer.val_dataloaders[0].sampler, shuffled=False)

    def on_test_start(self, trainer, pl_module):
        self._check_sampler(trainer.test_dataloaders[0].sampler, shuffled=False)
|
|
|
|
|
|
@RunIf(min_gpus=2, skip_windows=True)
def test_dataloader_distributed_sampler(tmpdir):
    """Under DDP spawn, the samplers seen during fit and test are checked by `DistribSamplerCallback`."""
    model = EvalModelTemplate()
    trainer_kwargs = dict(
        gpus=[0, 1],
        num_nodes=1,
        accelerator='ddp_spawn',
        default_root_dir=tmpdir,
        max_steps=1,
        callbacks=[DistribSamplerCallback()],
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)
    trainer.test(ckpt_path=None)
|
|
|
|
|
|
class ModelWithDataLoaderDistributedSampler(EvalModelTemplate):
    """Model whose train dataloader already carries a shuffling `DistributedSampler`."""

    def train_dataloader(self):
        base_dataset = super().train_dataloader().dataset
        return DataLoader(
            base_dataset,
            batch_size=self.batch_size,
            drop_last=False,
            sampler=DistributedSampler(base_dataset, shuffle=True),
            shuffle=False,
        )
|
|
|
|
|
|
@RunIf(min_gpus=2, skip_windows=True)
def test_dataloader_distributed_sampler_already_attached(tmpdir):
    """DDP training succeeds when the train loader already has a DistributedSampler and `replace_sampler_ddp=True`."""
    model = ModelWithDataLoaderDistributedSampler()
    trainer_kwargs = dict(
        gpus=[0, 1],
        num_nodes=1,
        accelerator='ddp_spawn',
        default_root_dir=tmpdir,
        max_steps=100,
        callbacks=[DistribSamplerCallback()],
        replace_sampler_ddp=True,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)
    assert trainer.state == TrainerState.FINISHED, "DDP Training failed"
|
|
|
|
|
|
@RunIf(min_gpus=3)
def test_batch_size_smaller_than_num_gpus(tmpdir):
    """Training must finish when the last batch holds fewer samples than there are GPUs."""
    # we need at least 3 gpus for this test
    num_gpus = 3
    batch_size = 3

    class CurrentTestModel(EvalModelTemplate):

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # batch norm doesn't work with batch size 1, so it is swapped for a ReLU
            self.c_d1_bn = torch.nn.ReLU()

        def training_step(self, *args, **kwargs):
            output = super().training_step(*args, **kwargs)
            # attaching metrics to the output is essential: they must be reduced
            # even on the last, undersized batch
            output['progress_bar'] = {'train_loss': output['loss']}
            return output

        def train_dataloader(self):
            # a dataset size not divisible by num_gpus, so the last batch is smaller than num_gpus
            size = num_gpus * batch_size + (num_gpus - 1)
            subset = Subset(super().train_dataloader().dataset, range(size))
            return DataLoader(subset, batch_size=self.batch_size, drop_last=False)

    hparams = EvalModelTemplate.get_default_hparams()
    hparams['batch_size'] = batch_size
    model = CurrentTestModel(**hparams)

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=0.1,
        limit_val_batches=0,
        gpus=num_gpus,
    )

    # the metric reduction also happens on the last batch, where there are fewer metrics than gpus
    trainer.fit(model)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
|
|
|
|
|
|
@pytest.mark.parametrize(['multiple_trainloader_mode', 'num_training_batches'], [
    pytest.param("min_size", 5),
    pytest.param("max_size_cycle", 10),
])
def test_fit_multiple_train_loaders(tmpdir, multiple_trainloader_mode, num_training_batches):
    """Integration test for multiple train loaders combined under each `multiple_trainloader_mode`."""
    model = EvalModelTemplate()

    model.train_dataloader = model.train_dataloader__multiple_mapping
    # todo: add also `train_dataloader__multiple_sequence`
    model.training_step = model.training_step__multiple_dataloaders

    trainer = Trainer(
        max_epochs=1,
        default_root_dir=tmpdir,
        multiple_trainloader_mode=multiple_trainloader_mode,
    )

    trainer.fit(model)
    # check completion through the trainer state, as the rest of this file does,
    # rather than relying on `fit`'s return value
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    # verify the num_training_batches according to the multiple_trainloader_mode
    assert trainer.num_training_batches == num_training_batches
|
|
|
|
|
|
@pytest.mark.parametrize('check_interval', [1.0])
def test_val_dataloader_not_implemented_error(tmpdir, check_interval):
    """Training finishes with the `not_implemented_error` val loader variant (IterableDataset-like)."""
    model = EvalModelTemplate()
    model.val_dataloader = model.val_dataloader__not_implemented_error

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=5,
        max_epochs=1,
        val_check_interval=check_interval,
    )
    trainer.fit(model)

    # the run must have finished cleanly
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
|
|
|
|
|
|
@pytest.mark.parametrize('check_interval', [50, 1.0])
def test_train_dataloader_not_implemented_error(tmpdir, check_interval):
    """Fitting must finish when both train and val loaders are the ``not_implemented_error`` variants.

    Those variants presumably model loaders without a usable length (e.g. ``IterableDataset``).
    Both an integer and a ``1.0`` ``val_check_interval`` are exercised.
    """
    model = EvalModelTemplate()
    model.train_dataloader, model.val_dataloader = (
        model.train_dataloader__not_implemented_error,
        model.val_dataloader__not_implemented_error,
    )

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=5,
        max_epochs=1,
        val_check_interval=check_interval,
    )
    trainer.fit(model)

    # the run has to reach the FINISHED state
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
|
|
|
|
|
|
def test_train_dataloader_not_implemented_error_failed(tmpdir):
    """A ``val_check_interval`` of ``0.5`` must be rejected for the ``not_implemented_error`` train loader.

    The error message is expected to mention an IterableDataset.
    """
    model = EvalModelTemplate()
    model.train_dataloader = model.train_dataloader__not_implemented_error

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=5,
        max_epochs=1,
        val_check_interval=0.5,
    )

    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
        trainer.fit(model)
|
|
|
|
|
|
def test_val_dataloader_not_implemented_error_failed(tmpdir):
    """Test not_implemented_error val data loader (e.g. IterableDataset).

    A fractional ``limit_val_batches`` (``0.5``) combined with the ``not_implemented_error`` val
    loader must raise a ``MisconfigurationException`` mentioning an IterableDataset.
    """
    model = EvalModelTemplate()
    model.val_dataloader = model.val_dataloader__not_implemented_error

    trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, limit_val_batches=0.5)

    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
        trainer.fit(model)
|
|
|
|
|
|
def test_test_dataloader_not_implemented_error_failed(tmpdir):
    """Test not_implemented_error test data loader (e.g. IterableDataset).

    A fractional ``limit_test_batches`` (``0.5``) combined with the ``not_implemented_error`` test
    loader must make ``trainer.test`` raise a ``MisconfigurationException`` mentioning an
    IterableDataset.
    """
    model = EvalModelTemplate()
    model.test_dataloader = model.test_dataloader__not_implemented_error

    trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, limit_test_batches=0.5)

    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
        trainer.test(model)
|
|
|
|
|
|
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
def test_dataloaders_load_only_once(tmpdir):
    """Without ``reload_dataloaders_every_epoch``, train and val loaders are each loaded exactly once.

    ``PL_DEV_DEBUG=1`` is set so that ``trainer.dev_debugger`` records the loads inspected below.
    """
    model = EvalModelTemplate()

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=0.3,
        limit_val_batches=0.3,
        max_epochs=3,
    )
    trainer.fit(model)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

    assert len(trainer.dev_debugger.val_dataloader_calls) == 1
    assert len(trainer.dev_debugger.test_dataloader_calls) == 0
    assert len(trainer.dev_debugger.train_dataloader_calls) == 1

    # verify the exact sequence; comparing whole lists (unlike `zip`, which stops at the shorter
    # input) also fails on missing or unexpected loads
    call_names = [call['name'] for call in trainer.dev_debugger.dataloader_sequence_calls]
    assert call_names == [
        'val_dataloader',
        'train_dataloader',
    ]
|
|
|
|
|
|
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
def test_dataloaders_load_only_once_val_interval(tmpdir):
    """Check dataloader loads with ``reload_dataloaders_every_epoch=True`` and ``val_check_interval=0.3``.

    Despite the test name, the assertions expect the train loader to be reloaded every epoch and
    the val loader before every validation run. ``PL_DEV_DEBUG=1`` is set so that
    ``trainer.dev_debugger`` records the loads inspected below.
    """
    model = EvalModelTemplate()

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=10,
        limit_val_batches=10,
        val_check_interval=0.3,
        reload_dataloaders_every_epoch=True,
        max_epochs=3,
    )
    trainer.fit(model)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

    trainer.test()

    assert len(trainer.dev_debugger.val_dataloader_calls) == 10
    assert len(trainer.dev_debugger.test_dataloader_calls) == 1
    assert len(trainer.dev_debugger.train_dataloader_calls) == 3

    # verify the exact sequence: one leading val load, then per epoch a train load followed by
    # three val loads, then the test load. Comparing whole lists (unlike `zip`, which stops at
    # the shorter input) also fails on missing or unexpected loads.
    call_names = [call['name'] for call in trainer.dev_debugger.dataloader_sequence_calls]
    assert call_names == [
        'val_dataloader',
        'train_dataloader',
        'val_dataloader',
        'val_dataloader',
        'val_dataloader',
        'train_dataloader',
        'val_dataloader',
        'val_dataloader',
        'val_dataloader',
        'train_dataloader',
        'val_dataloader',
        'val_dataloader',
        'val_dataloader',
        'test_dataloader',
    ]
|
|
|
|
|
|
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
def test_dataloaders_load_only_once_no_sanity_check(tmpdir):
    """With ``num_sanity_val_steps=0``, the train loader is loaded first and each loader only once.

    ``PL_DEV_DEBUG=1`` is set so that ``trainer.dev_debugger`` records the loads inspected below.
    """
    model = EvalModelTemplate()

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=0.3,
        limit_val_batches=0.3,
        num_sanity_val_steps=0,
        max_epochs=3,
    )
    trainer.fit(model)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

    assert len(trainer.dev_debugger.val_dataloader_calls) == 1
    assert len(trainer.dev_debugger.test_dataloader_calls) == 0
    assert len(trainer.dev_debugger.train_dataloader_calls) == 1

    # verify the exact sequence; comparing whole lists (unlike `zip`, which stops at the shorter
    # input) also fails on missing or unexpected loads
    call_names = [call['name'] for call in trainer.dev_debugger.dataloader_sequence_calls]
    assert call_names == [
        'train_dataloader',
        'val_dataloader',
    ]
|
|
|
|
|
|
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
def test_dataloaders_load_every_epoch(tmpdir):
    """With ``reload_dataloaders_every_epoch=True``, train and val loaders are reloaded each epoch.

    ``PL_DEV_DEBUG=1`` is set so that ``trainer.dev_debugger`` records the loads inspected below.
    """
    model = EvalModelTemplate()

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=0.3,
        limit_val_batches=0.3,
        reload_dataloaders_every_epoch=True,
        max_epochs=3,
    )
    trainer.fit(model)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

    trainer.test()

    assert len(trainer.dev_debugger.val_dataloader_calls) == 4
    assert len(trainer.dev_debugger.train_dataloader_calls) == 3
    assert len(trainer.dev_debugger.test_dataloader_calls) == 1

    # verify the exact sequence; comparing whole lists (unlike `zip`, which stops at the shorter
    # input) also fails on missing or unexpected loads
    call_names = [call['name'] for call in trainer.dev_debugger.dataloader_sequence_calls]
    assert call_names == [
        'val_dataloader',
        'train_dataloader',
        'val_dataloader',
        'train_dataloader',
        'val_dataloader',
        'train_dataloader',
        'val_dataloader',
        'test_dataloader',
    ]
|
|
|
|
|
|
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
def test_dataloaders_load_every_epoch_no_sanity_check(tmpdir):
    """Per-epoch reloading with ``num_sanity_val_steps=0``: each epoch loads train, then val.

    ``PL_DEV_DEBUG=1`` is set so that ``trainer.dev_debugger`` records the loads inspected below.
    """
    model = EvalModelTemplate()

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=0.3,
        limit_val_batches=0.3,
        num_sanity_val_steps=0,
        reload_dataloaders_every_epoch=True,
        max_epochs=3,
    )
    trainer.fit(model)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

    trainer.test()

    assert len(trainer.dev_debugger.val_dataloader_calls) == 3
    assert len(trainer.dev_debugger.train_dataloader_calls) == 3
    assert len(trainer.dev_debugger.test_dataloader_calls) == 1

    # verify the exact sequence; comparing whole lists (unlike `zip`, which stops at the shorter
    # input) also fails on missing or unexpected loads
    call_names = [call['name'] for call in trainer.dev_debugger.dataloader_sequence_calls]
    assert call_names == [
        'train_dataloader',
        'val_dataloader',
        'train_dataloader',
        'val_dataloader',
        'train_dataloader',
        'val_dataloader',
        'test_dataloader',
    ]
|
|
|
|
|
|
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
def test_dataloaders_load_only_once_passed_loaders(tmpdir):
    """Loaders passed directly to ``fit``/``test`` are each loaded exactly once.

    The model's own dataloader hooks are set to ``None`` so that only the explicitly passed
    loaders are available. ``PL_DEV_DEBUG=1`` is set so that ``trainer.dev_debugger`` records
    the loads inspected below.
    """
    model = EvalModelTemplate()
    train_loader = model.train_dataloader()
    model.train_dataloader = None
    val_loader = model.val_dataloader()
    model.val_dataloader = None
    test_loader = model.test_dataloader()
    model.test_dataloader = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=0.3,
        limit_val_batches=0.3,
        max_epochs=3,
    )
    trainer.fit(model, train_loader, val_loader)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

    trainer.test(test_dataloaders=test_loader)

    assert len(trainer.dev_debugger.val_dataloader_calls) == 1
    assert len(trainer.dev_debugger.test_dataloader_calls) == 1
    assert len(trainer.dev_debugger.train_dataloader_calls) == 1

    # verify the exact sequence, including the test load asserted above. The previous
    # `zip`-based check omitted it, and `zip` stops at the shorter input, which hid the gap.
    call_names = [call['name'] for call in trainer.dev_debugger.dataloader_sequence_calls]
    assert call_names == [
        'val_dataloader',
        'train_dataloader',
        'test_dataloader',
    ]
|
|
|
|
|
|
def test_replace_sampler_with_multiprocessing_context(tmpdir):
    """``trainer.replace_sampler`` must keep the loader's ``multiprocessing_context`` on the rebuilt loader."""
    dataset = RandomDataset(32, 64)
    spawn_ctx = 'spawn'
    original_loader = DataLoader(
        dataset,
        batch_size=32,
        num_workers=2,
        multiprocessing_context=spawn_ctx,
        shuffle=True,
    )

    trainer = Trainer(max_epochs=1, progress_bar_refresh_rate=20, overfit_batches=5)

    rebuilt_loader = trainer.replace_sampler(
        original_loader,
        SequentialSampler(original_loader.dataset),
    )
    assert rebuilt_loader.multiprocessing_context == original_loader.multiprocessing_context
|