update tests to not rely on patched dataloaders (#9905)
This commit is contained in:
parent
98c0a110e0
commit
b530b7afd2
|
@ -95,7 +95,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
|
|||
)
|
||||
|
||||
with pytest.raises(MisconfigurationException, match=r"You restored a checkpoint with current_epoch"):
|
||||
new_trainer.fit(model)
|
||||
new_trainer.fit(model, datamodule=dm)
|
||||
|
||||
|
||||
def test_early_stopping_no_extraneous_invocations(tmpdir):
|
||||
|
|
|
@ -340,7 +340,7 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir):
|
|||
new_trainer.test(pretrained_model)
|
||||
pretrained_model.cpu()
|
||||
|
||||
dataloaders = model.test_dataloader()
|
||||
dataloaders = dm.test_dataloader()
|
||||
if not isinstance(dataloaders, list):
|
||||
dataloaders = [dataloaders]
|
||||
|
||||
|
@ -539,7 +539,7 @@ def test_dp_resume(tmpdir):
|
|||
# haven't trained with the new loaded model
|
||||
new_trainer.state.stage = RunningStage.VALIDATING
|
||||
|
||||
dataloader = self.train_dataloader()
|
||||
dataloader = dm.train_dataloader()
|
||||
tpipes.run_prediction_eval_model_template(self.trainer.lightning_module, dataloader=dataloader)
|
||||
self.on_pretrain_routine_end_called = True
|
||||
|
||||
|
|
|
@ -267,19 +267,19 @@ def test_loader_detaching():
|
|||
|
||||
class LoaderTestModel(BoringModel):
|
||||
def training_step(self, batch, batch_idx):
|
||||
assert len(model.train_dataloader()) == 10
|
||||
assert len(self.trainer.train_dataloader.loaders) == 10
|
||||
return super().training_step(batch, batch_idx)
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
assert len(model.val_dataloader()) == 10
|
||||
assert len(self.trainer.val_dataloaders[0]) == 10
|
||||
return super().validation_step(batch, batch_idx)
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
assert len(model.test_dataloader()) == 10
|
||||
assert len(self.trainer.test_dataloaders[0]) == 10
|
||||
return super().test_step(batch, batch_idx)
|
||||
|
||||
def predict_step(self, batch, batch_idx, dataloader_idx=None):
|
||||
assert len(model.predict_dataloader()) == 10
|
||||
assert len(self.trainer.predict_dataloaders[0]) == 10
|
||||
return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)
|
||||
|
||||
loader = DataLoader(RandomDataset(32, 10), batch_size=1)
|
||||
|
|
|
@ -184,7 +184,7 @@ def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n):
|
|||
model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
|
||||
model.test_step = model.test_step__multiple_dataloaders
|
||||
|
||||
# train, multiple val and multiple test passed to fit
|
||||
# multiple val dataloaders passed to fit
|
||||
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)
|
||||
trainer.fit(model, train_dataloader=model.dataloader(train=True), val_dataloaders=dataloaders)
|
||||
|
||||
|
@ -195,11 +195,11 @@ def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n):
|
|||
ckpt_path = trainer.checkpoint_callback.best_model_path
|
||||
|
||||
trainer.test(test_dataloaders=dataloaders, ckpt_path=ckpt_path)
|
||||
trainer.validate(val_dataloaders=dataloaders, ckpt_path=ckpt_path)
|
||||
|
||||
assert len(trainer.val_dataloaders) == n
|
||||
assert len(trainer.test_dataloaders) == n
|
||||
|
||||
trainer.validate(val_dataloaders=dataloaders, ckpt_path=ckpt_path)
|
||||
assert len(trainer.val_dataloaders) == n
|
||||
|
||||
|
||||
class DummyModel(BoringModel):
|
||||
def training_step(self, batch, batch_idx):
|
||||
|
@ -551,17 +551,15 @@ def test_mixing_of_dataloader_options(tmpdir, ckpt_path):
|
|||
# fit model
|
||||
trainer = Trainer(**trainer_options)
|
||||
trainer.fit(model, val_dataloaders=model.dataloader(train=False))
|
||||
assert trainer.state.finished, f"Training failed with {trainer.state}"
|
||||
|
||||
# fit model
|
||||
trainer = Trainer(**trainer_options)
|
||||
trainer.fit(model, val_dataloaders=model.dataloader(train=False))
|
||||
assert trainer.state.finished, f"Training failed with {trainer.state}"
|
||||
assert len(trainer.val_dataloaders) == 1, f"`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}"
|
||||
|
||||
if ckpt_path == "specific":
|
||||
ckpt_path = trainer.checkpoint_callback.best_model_path
|
||||
trainer.test(test_dataloaders=model.dataloader(train=False), ckpt_path=ckpt_path)
|
||||
|
||||
assert len(trainer.val_dataloaders) == 1, f"`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}"
|
||||
assert (
|
||||
len(trainer.test_dataloaders) == 1
|
||||
), f"`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}"
|
||||
|
@ -1313,8 +1311,8 @@ def test_dataloaders_load_only_once_passed_loaders(tmpdir):
|
|||
|
||||
|
||||
def test_dataloaders_reset_and_attach(tmpdir):
|
||||
"""Test that repeated calls to Trainer.{fit,validate,test,predict} properly reset and dataloaders before
|
||||
attaching the new one."""
|
||||
"""Test that repeated calls to Trainer.{fit,validate,test,predict} properly reset dataloaders before attaching
|
||||
the new one."""
|
||||
# the assertions compare the datasets and not dataloaders since we patch and replace the samplers
|
||||
dataloader_0 = DataLoader(dataset=RandomDataset(32, 64))
|
||||
dataloader_1 = DataLoader(dataset=RandomDataset(32, 64))
|
||||
|
|
|
@ -84,6 +84,7 @@ def test_overfit_batch_limits(tmpdir):
|
|||
# test train loader applies correct limits
|
||||
# ------------------------------------------------------
|
||||
trainer = Trainer(overfit_batches=4)
|
||||
trainer.data_connector.attach_dataloaders(model=model)
|
||||
trainer.reset_train_dataloader(model)
|
||||
assert trainer.num_training_batches == 4
|
||||
|
||||
|
@ -93,6 +94,7 @@ def test_overfit_batch_limits(tmpdir):
|
|||
assert torch.eq(ya, yb).all()
|
||||
|
||||
trainer = Trainer(overfit_batches=0.11)
|
||||
trainer.data_connector.attach_dataloaders(model=model)
|
||||
trainer.reset_train_dataloader(model)
|
||||
# The dataloader should have been overwritten with a Sequential sampler.
|
||||
assert trainer.train_dataloader is not train_loader
|
||||
|
@ -111,7 +113,9 @@ def test_overfit_batch_limits(tmpdir):
|
|||
# ------------------------------------------------------
|
||||
# test overfit_batches as percent
|
||||
# ------------------------------------------------------
|
||||
loader_num_batches, dataloaders = Trainer(overfit_batches=0.11)._reset_eval_dataloader(split, model=model)
|
||||
trainer = Trainer(overfit_batches=0.11)
|
||||
trainer.data_connector.attach_dataloaders(model)
|
||||
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
|
||||
assert loader_num_batches[0] == num_train_samples
|
||||
|
||||
# make sure we turned off shuffle for the user
|
||||
|
@ -125,23 +129,35 @@ def test_overfit_batch_limits(tmpdir):
|
|||
# ------------------------------------------------------
|
||||
# test overfit_batches as int
|
||||
# ------------------------------------------------------
|
||||
loader_num_batches, dataloaders = Trainer(overfit_batches=1)._reset_eval_dataloader(split, model=model)
|
||||
trainer = Trainer(overfit_batches=1)
|
||||
trainer.data_connector.attach_dataloaders(model)
|
||||
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
|
||||
assert loader_num_batches[0] == 1
|
||||
loader_num_batches, dataloaders = Trainer(overfit_batches=5)._reset_eval_dataloader(split, model=model)
|
||||
trainer = Trainer(overfit_batches=5)
|
||||
trainer.data_connector.attach_dataloaders(model)
|
||||
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
|
||||
assert loader_num_batches[0] == 5
|
||||
|
||||
# ------------------------------------------------------
|
||||
# test limit_xxx_batches as percent AND int
|
||||
# ------------------------------------------------------
|
||||
if split == RunningStage.VALIDATING:
|
||||
loader_num_batches, dataloaders = Trainer(limit_val_batches=0.1)._reset_eval_dataloader(split, model=model)
|
||||
trainer = Trainer(limit_val_batches=0.1)
|
||||
trainer.data_connector.attach_dataloaders(model)
|
||||
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
|
||||
assert loader_num_batches[0] == int(0.1 * len(val_loader))
|
||||
|
||||
loader_num_batches, dataloaders = Trainer(limit_val_batches=10)._reset_eval_dataloader(split, model=model)
|
||||
trainer = Trainer(limit_val_batches=10)
|
||||
trainer.data_connector.attach_dataloaders(model)
|
||||
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
|
||||
assert loader_num_batches[0] == 10
|
||||
else:
|
||||
loader_num_batches, dataloaders = Trainer(limit_test_batches=0.1)._reset_eval_dataloader(split, model=model)
|
||||
trainer = Trainer(limit_test_batches=0.1)
|
||||
trainer.data_connector.attach_dataloaders(model)
|
||||
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
|
||||
assert loader_num_batches[0] == int(0.1 * len(test_loader))
|
||||
|
||||
loader_num_batches, dataloaders = Trainer(limit_test_batches=10)._reset_eval_dataloader(split, model=model)
|
||||
trainer = Trainer(limit_test_batches=10)
|
||||
trainer.data_connector.attach_dataloaders(model)
|
||||
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
|
||||
assert loader_num_batches[0] == 10
|
||||
|
|
|
@ -220,9 +220,12 @@ def test_error_on_dataloader_passed_to_fit(tmpdir):
|
|||
limit_train_batches=0.2,
|
||||
auto_scale_batch_size="power",
|
||||
)
|
||||
fit_options = dict(train_dataloader=model.dataloader(train=True))
|
||||
fit_options = dict(train_dataloaders=model.dataloader(train=True))
|
||||
|
||||
with pytest.raises(MisconfigurationException):
|
||||
with pytest.raises(
|
||||
MisconfigurationException,
|
||||
match="The batch scaling feature cannot be used with dataloaders passed directly",
|
||||
):
|
||||
trainer.tune(model, **fit_options)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue