diff --git a/src/lightning/app/components/multi_node/trainer.py b/src/lightning/app/components/multi_node/trainer.py index aaa52fad34..047b781cb9 100644 --- a/src/lightning/app/components/multi_node/trainer.py +++ b/src/lightning/app/components/multi_node/trainer.py @@ -52,7 +52,7 @@ class _LightningTrainerRunExecutor(_PyTorchSpawnRunExecutor): try: pkg = importlib.import_module(pkg_name) trainers.append(pkg.Trainer) - strategies.append(pkg.strategies.DDPSpawnStrategy) + strategies.append(pkg.strategies.DDPStrategy) mps_accelerators.append(pkg.accelerators.MPSAccelerator) except (ImportError, ModuleNotFoundError): continue diff --git a/tests/tests_app/components/multi_node/test_trainer.py b/tests/tests_app/components/multi_node/test_trainer.py index a3bb2ea91b..55282f82f7 100644 --- a/tests/tests_app/components/multi_node/test_trainer.py +++ b/tests/tests_app/components/multi_node/test_trainer.py @@ -93,4 +93,4 @@ def test_trainer_run_executor_arguments_choices( @pytest.mark.skipif(not module_available("lightning"), reason="lightning not available") def test_trainer_run_executor_invalid_strategy_instances(): with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."): - _, _ = _get_args_after_tracer_injection(strategy=pl.strategies.DDPSpawnStrategy()) + _, _ = _get_args_after_tracer_injection(strategy=pl.strategies.DDPStrategy(start_method="spawn"))