update tests to not rely on patched dataloaders (#9905)

Adrian Wälchli authored 2021-10-12 12:45:28 +02:00, committed by GitHub
parent 98c0a110e0
commit b530b7afd2
6 changed files with 43 additions and 26 deletions
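
The theme of the whole diff: these tests used to read dataloaders through `*_dataloader` hooks that the trainer had patched onto the LightningModule. With that patching removed, they now ask the loader's real owner instead, either the datamodule or the trainer itself. A minimal runnable sketch of the new convention (pytorch-lightning ~1.5 assumed; the datamodule below is illustrative, not part of the commit):

    import torch
    from torch.utils.data import DataLoader, Dataset
    from pytorch_lightning import LightningDataModule

    class RandomDataset(Dataset):
        # mirrors the shape of Lightning's test helper of the same name
        def __init__(self, size, length):
            self.data = torch.randn(length, size)

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index]

    class SimpleDataModule(LightningDataModule):
        def test_dataloader(self):
            return DataLoader(RandomDataset(32, 64))

    dm = SimpleDataModule()
    # read loaders from their owner rather than from a patched `model.test_dataloader`
    dataloaders = dm.test_dataloader()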


@@ -95,7 +95,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
     )
     with pytest.raises(MisconfigurationException, match=r"You restored a checkpoint with current_epoch"):
-        new_trainer.fit(model)
+        new_trainer.fit(model, datamodule=dm)


 def test_early_stopping_no_extraneous_invocations(tmpdir):


@@ -340,7 +340,7 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir):
     new_trainer.test(pretrained_model)
     pretrained_model.cpu()

-    dataloaders = model.test_dataloader()
+    dataloaders = dm.test_dataloader()
     if not isinstance(dataloaders, list):
         dataloaders = [dataloaders]

@@ -539,7 +539,7 @@ def test_dp_resume(tmpdir):
             # haven't trained with the new loaded model
             new_trainer.state.stage = RunningStage.VALIDATING

-            dataloader = self.train_dataloader()
+            dataloader = dm.train_dataloader()
             tpipes.run_prediction_eval_model_template(self.trainer.lightning_module, dataloader=dataloader)
             self.on_pretrain_routine_end_called = True


@@ -267,19 +267,19 @@ def test_loader_detaching():

     class LoaderTestModel(BoringModel):
         def training_step(self, batch, batch_idx):
-            assert len(model.train_dataloader()) == 10
+            assert len(self.trainer.train_dataloader.loaders) == 10
             return super().training_step(batch, batch_idx)

         def validation_step(self, batch, batch_idx):
-            assert len(model.val_dataloader()) == 10
+            assert len(self.trainer.val_dataloaders[0]) == 10
             return super().validation_step(batch, batch_idx)

         def test_step(self, batch, batch_idx):
-            assert len(model.test_dataloader()) == 10
+            assert len(self.trainer.test_dataloaders[0]) == 10
             return super().test_step(batch, batch_idx)

         def predict_step(self, batch, batch_idx, dataloader_idx=None):
-            assert len(model.predict_dataloader()) == 10
+            assert len(self.trainer.predict_dataloaders[0]) == 10
             return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)

     loader = DataLoader(RandomDataset(32, 10), batch_size=1)
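
For reference, the pattern this hunk switches to, as a condensed runnable sketch: the step hooks assert against the trainer's own dataloader properties instead of hooks patched onto the model (assumes pytorch-lightning ~1.5; `BoringModel` and `RandomDataset` live in the repo's test helpers):

    from pytorch_lightning import Trainer
    from torch.utils.data import DataLoader
    from tests.helpers.boring_model import BoringModel, RandomDataset  # path inside the Lightning repo

    class LoaderTestModel(BoringModel):
        def training_step(self, batch, batch_idx):
            # the trainer wraps the attached train loader, hence the `.loaders` hop
            assert len(self.trainer.train_dataloader.loaders) == 10
            return super().training_step(batch, batch_idx)

        def validation_step(self, batch, batch_idx):
            # eval loaders are stored as a list, one entry per attached dataloader
            assert len(self.trainer.val_dataloaders[0]) == 10
            return super().validation_step(batch, batch_idx)

    loader = DataLoader(RandomDataset(32, 10), batch_size=1)  # 10 samples -> length 10
    model = LoaderTestModel()
    trainer = Trainer(fast_dev_run=1)
    trainer.fit(model, train_dataloaders=loader, val_dataloaders=loader)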


@@ -184,7 +184,7 @@ def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n):
         model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
         model.test_step = model.test_step__multiple_dataloaders

-    # train, multiple val and multiple test passed to fit
+    # multiple val dataloaders passed to fit
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)
     trainer.fit(model, train_dataloader=model.dataloader(train=True), val_dataloaders=dataloaders)

@@ -195,11 +195,11 @@ def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n):
         ckpt_path = trainer.checkpoint_callback.best_model_path

     trainer.test(test_dataloaders=dataloaders, ckpt_path=ckpt_path)
-    trainer.validate(val_dataloaders=dataloaders, ckpt_path=ckpt_path)
-
-    assert len(trainer.val_dataloaders) == n
     assert len(trainer.test_dataloaders) == n

+    trainer.validate(val_dataloaders=dataloaders, ckpt_path=ckpt_path)
+    assert len(trainer.val_dataloaders) == n


 class DummyModel(BoringModel):
     def training_step(self, batch, batch_idx):
@@ -551,17 +551,15 @@ def test_mixing_of_dataloader_options(tmpdir, ckpt_path):
     # fit model
     trainer = Trainer(**trainer_options)
     trainer.fit(model, val_dataloaders=model.dataloader(train=False))
     assert trainer.state.finished, f"Training failed with {trainer.state}"

     # fit model
     trainer = Trainer(**trainer_options)
     trainer.fit(model, val_dataloaders=model.dataloader(train=False))
     assert trainer.state.finished, f"Training failed with {trainer.state}"
+    assert len(trainer.val_dataloaders) == 1, f"`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}"
     if ckpt_path == "specific":
         ckpt_path = trainer.checkpoint_callback.best_model_path
     trainer.test(test_dataloaders=model.dataloader(train=False), ckpt_path=ckpt_path)

-    assert len(trainer.val_dataloaders) == 1, f"`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}"
     assert (
         len(trainer.test_dataloaders) == 1
     ), f"`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}"
@@ -1313,8 +1311,8 @@ def test_dataloaders_load_only_once_passed_loaders(tmpdir):

 def test_dataloaders_reset_and_attach(tmpdir):
-    """Test that repeated calls to Trainer.{fit,validate,test,predict} properly reset and dataloaders before
-    attaching the new one."""
+    """Test that repeated calls to Trainer.{fit,validate,test,predict} properly reset dataloaders before attaching
+    the new one."""
     # the assertions compare the datasets and not dataloaders since we patch and replace the samplers
     dataloader_0 = DataLoader(dataset=RandomDataset(32, 64))
     dataloader_1 = DataLoader(dataset=RandomDataset(32, 64))
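
The fixed docstring names the behaviour worth illustrating: every Trainer entry point resets the previously attached loaders before attaching the new ones, which is also why the `test`/`validate` assertions earlier in this file were reordered to sit directly after their entry point. A sketch of the assertions, continuing the `dataloader_0`/`dataloader_1` names from the hunk above (the `model`/`trainer` setup is assumed from the surrounding test; datasets are compared because the trainer re-wraps the loaders and their samplers, and the kwarg names follow those used elsewhere in this commit):

    trainer.fit(model, train_dataloaders=dataloader_0)
    assert trainer.train_dataloader.loaders.dataset is dataloader_0.dataset

    trainer.fit(model, train_dataloaders=dataloader_1)  # resets, then attaches the new loader
    assert trainer.train_dataloader.loaders.dataset is dataloader_1.dataset

    trainer.validate(model, val_dataloaders=dataloader_0)
    assert trainer.val_dataloaders[0].dataset is dataloader_0.dataset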


@@ -84,6 +84,7 @@ def test_overfit_batch_limits(tmpdir):
     # test train loader applies correct limits
     # ------------------------------------------------------
     trainer = Trainer(overfit_batches=4)
+    trainer.data_connector.attach_dataloaders(model=model)
     trainer.reset_train_dataloader(model)
     assert trainer.num_training_batches == 4

@@ -93,6 +94,7 @@ def test_overfit_batch_limits(tmpdir):
     assert torch.eq(ya, yb).all()

     trainer = Trainer(overfit_batches=0.11)
+    trainer.data_connector.attach_dataloaders(model=model)
     trainer.reset_train_dataloader(model)
     # The dataloader should have been overwritten with a Sequential sampler.
     assert trainer.train_dataloader is not train_loader

@@ -111,7 +113,9 @@ def test_overfit_batch_limits(tmpdir):
     # ------------------------------------------------------
     # test overfit_batches as percent
     # ------------------------------------------------------
-    loader_num_batches, dataloaders = Trainer(overfit_batches=0.11)._reset_eval_dataloader(split, model=model)
+    trainer = Trainer(overfit_batches=0.11)
+    trainer.data_connector.attach_dataloaders(model)
+    loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
     assert loader_num_batches[0] == num_train_samples

     # make sure we turned off shuffle for the user

@@ -125,23 +129,35 @@ def test_overfit_batch_limits(tmpdir):
     # ------------------------------------------------------
     # test overfit_batches as int
     # ------------------------------------------------------
-    loader_num_batches, dataloaders = Trainer(overfit_batches=1)._reset_eval_dataloader(split, model=model)
+    trainer = Trainer(overfit_batches=1)
+    trainer.data_connector.attach_dataloaders(model)
+    loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
     assert loader_num_batches[0] == 1

-    loader_num_batches, dataloaders = Trainer(overfit_batches=5)._reset_eval_dataloader(split, model=model)
+    trainer = Trainer(overfit_batches=5)
+    trainer.data_connector.attach_dataloaders(model)
+    loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
     assert loader_num_batches[0] == 5

     # ------------------------------------------------------
     # test limit_xxx_batches as percent AND int
     # ------------------------------------------------------
     if split == RunningStage.VALIDATING:
-        loader_num_batches, dataloaders = Trainer(limit_val_batches=0.1)._reset_eval_dataloader(split, model=model)
+        trainer = Trainer(limit_val_batches=0.1)
+        trainer.data_connector.attach_dataloaders(model)
+        loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
         assert loader_num_batches[0] == int(0.1 * len(val_loader))

-        loader_num_batches, dataloaders = Trainer(limit_val_batches=10)._reset_eval_dataloader(split, model=model)
+        trainer = Trainer(limit_val_batches=10)
+        trainer.data_connector.attach_dataloaders(model)
+        loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
         assert loader_num_batches[0] == 10
     else:
-        loader_num_batches, dataloaders = Trainer(limit_test_batches=0.1)._reset_eval_dataloader(split, model=model)
+        trainer = Trainer(limit_test_batches=0.1)
+        trainer.data_connector.attach_dataloaders(model)
+        loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
         assert loader_num_batches[0] == int(0.1 * len(test_loader))

-        loader_num_batches, dataloaders = Trainer(limit_test_batches=10)._reset_eval_dataloader(split, model=model)
+        trainer = Trainer(limit_test_batches=10)
+        trainer.data_connector.attach_dataloaders(model)
+        loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
         assert loader_num_batches[0] == 10
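
Every edit in this file is the same mechanical change: `reset_train_dataloader` and `_reset_eval_dataloader` no longer find loaders patched onto the model, so each freshly constructed `Trainer` must first have the model's loaders attached through its data connector. The recipe, condensed (names as in the test above; `split` is a `RunningStage` value such as `RunningStage.VALIDATING`):

    trainer = Trainer(overfit_batches=4)
    trainer.data_connector.attach_dataloaders(model=model)  # previously implicit via patching
    trainer.reset_train_dataloader(model)
    assert trainer.num_training_batches == 4

    trainer = Trainer(limit_val_batches=10)
    trainer.data_connector.attach_dataloaders(model)
    loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
    assert loader_num_batches[0] == 10  # holds when split is RunningStage.VALIDATING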


@@ -220,9 +220,12 @@ def test_error_on_dataloader_passed_to_fit(tmpdir):
         limit_train_batches=0.2,
         auto_scale_batch_size="power",
     )
-    fit_options = dict(train_dataloader=model.dataloader(train=True))
+    fit_options = dict(train_dataloaders=model.dataloader(train=True))

-    with pytest.raises(MisconfigurationException):
+    with pytest.raises(
+        MisconfigurationException,
+        match="The batch scaling feature cannot be used with dataloaders passed directly",
+    ):
         trainer.tune(model, **fit_options)
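
The broadened `pytest.raises` call pins down which `MisconfigurationException` is expected. `match=` is standard pytest behaviour (the pattern is searched with `re.search` against the string representation of the exception), so an unrelated misconfiguration error no longer lets the test pass. A generic, self-contained sketch of the idiom (the function and message here are invented for illustration):

    import pytest

    def configure(batch_size: int) -> None:
        if batch_size <= 0:
            raise ValueError(f"`batch_size` must be positive, got {batch_size}")

    def test_rejects_bad_batch_size() -> None:
        # without `match`, any ValueError raised in the block would satisfy the test;
        # the pattern makes the assertion specific to this failure mode
        with pytest.raises(ValueError, match="must be positive"):
            configure(0)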