remove deprecated test (#3820)
This commit is contained in:
parent
d9bc95f83e
commit
0fb8c54fda
|
@ -119,63 +119,6 @@ def test_result_obj_predictions(tmpdir, test_option, do_train, gpus):
|
|||
assert len(predictions) == len(dm.mnist_test)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
|
||||
def test_result_obj_predictions_ddp_spawn(tmpdir):
|
||||
seed_everything(4321)
|
||||
|
||||
distributed_backend = 'ddp_spawn'
|
||||
option = 0
|
||||
|
||||
import os
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
|
||||
|
||||
dm = TrialMNISTDataModule(tmpdir)
|
||||
|
||||
prediction_file = Path(tmpdir) / 'predictions.pt'
|
||||
|
||||
model = EvalModelTemplate(learning_rate=0.005)
|
||||
model.test_option = option
|
||||
model.prediction_file = prediction_file.as_posix()
|
||||
model.test_step = model.test_step_result_preds
|
||||
model.test_step_end = None
|
||||
model.test_epoch_end = None
|
||||
model.test_end = None
|
||||
|
||||
prediction_files = [Path(tmpdir) / 'predictions_rank_0.pt', Path(tmpdir) / 'predictions_rank_1.pt']
|
||||
for prediction_file in prediction_files:
|
||||
if prediction_file.exists():
|
||||
prediction_file.unlink()
|
||||
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=3,
|
||||
weights_summary=None,
|
||||
deterministic=True,
|
||||
distributed_backend=distributed_backend,
|
||||
gpus=[0, 1]
|
||||
)
|
||||
|
||||
# Prediction file shouldn't exist yet because we haven't done anything
|
||||
# assert not model.prediction_file.exists()
|
||||
|
||||
result = trainer.fit(model, dm)
|
||||
assert result == 1
|
||||
result = trainer.test(datamodule=dm)
|
||||
result = result[0]
|
||||
assert result['test_loss'] < 0.6
|
||||
assert result['test_acc'] > 0.8
|
||||
|
||||
dm.setup('test')
|
||||
|
||||
# check prediction file now exists and is of expected length
|
||||
size = 0
|
||||
for prediction_file in prediction_files:
|
||||
assert prediction_file.exists()
|
||||
predictions = torch.load(prediction_file)
|
||||
size += len(predictions)
|
||||
assert size == len(dm.mnist_test)
|
||||
|
||||
|
||||
def test_result_gather_stack():
|
||||
""" Test that tensors get concatenated when they all have the same shape. """
|
||||
outputs = [
|
||||
|
|
Loading…
Reference in New Issue