diff --git a/tests/test_models.py b/tests/test_models.py index ea0741da41..c20c927d6a 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -57,7 +57,7 @@ def test_amp_gpu_ddp_slurm_managed(): # simulate setting slurm flags os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0]) - os.environ['SLURM_LOCALID'] = 0 + os.environ['SLURM_LOCALID'] = str(0) hparams = get_hparams() model = LightningTestModel(hparams)