remove deprecated test (#3820)

This commit is contained in:
William Falcon 2020-10-03 13:21:10 -04:00 committed by GitHub
parent d9bc95f83e
commit 0fb8c54fda
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 0 additions and 57 deletions

View File

@ -119,63 +119,6 @@ def test_result_obj_predictions(tmpdir, test_option, do_train, gpus):
assert len(predictions) == len(dm.mnist_test)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_result_obj_predictions_ddp_spawn(tmpdir):
seed_everything(4321)
distributed_backend = 'ddp_spawn'
option = 0
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
dm = TrialMNISTDataModule(tmpdir)
prediction_file = Path(tmpdir) / 'predictions.pt'
model = EvalModelTemplate(learning_rate=0.005)
model.test_option = option
model.prediction_file = prediction_file.as_posix()
model.test_step = model.test_step_result_preds
model.test_step_end = None
model.test_epoch_end = None
model.test_end = None
prediction_files = [Path(tmpdir) / 'predictions_rank_0.pt', Path(tmpdir) / 'predictions_rank_1.pt']
for prediction_file in prediction_files:
if prediction_file.exists():
prediction_file.unlink()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=3,
weights_summary=None,
deterministic=True,
distributed_backend=distributed_backend,
gpus=[0, 1]
)
# Prediction file shouldn't exist yet because we haven't done anything
# assert not model.prediction_file.exists()
result = trainer.fit(model, dm)
assert result == 1
result = trainer.test(datamodule=dm)
result = result[0]
assert result['test_loss'] < 0.6
assert result['test_acc'] > 0.8
dm.setup('test')
# check prediction file now exists and is of expected length
size = 0
for prediction_file in prediction_files:
assert prediction_file.exists()
predictions = torch.load(prediction_file)
size += len(predictions)
assert size == len(dm.mnist_test)
def test_result_gather_stack():
""" Test that tensors get concatenated when they all have the same shape. """
outputs = [