diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 3dbb7b7c07..d6641c2f7a 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -988,3 +988,17 @@ def test_trainer_setup_call(tmpdir): trainer.test(ckpt_path=None) assert trainer.stage == 'test' assert trainer.get_model().stage == 'test' + + +def test_trainer_ddp_spawn_none_checkpoint(tmpdir): + model = EvalModelTemplate() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + checkpoint_callback=None, + distributed_backend="ddp_spawn" + ) + assert trainer.checkpoint_callback is None + result = trainer.fit(model) + assert trainer.checkpoint_callback is None + assert result == 1