diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst
index cda6b3c6d4..aa7d03babe 100644
--- a/docs/source/advanced/multi_gpu.rst
+++ b/docs/source/advanced/multi_gpu.rst
@@ -675,7 +675,7 @@ To use Sharded Training, you need to first install FairScale using the command b
 .. code-block:: python
 
     # train using Sharded DDP
-    trainer = Trainer(accelerator='ddp', plugins='ddp_sharded')
+    trainer = Trainer(plugins='ddp_sharded')
 
 Sharded Training can work across all DDP variants by adding the additional ``--plugins ddp_sharded`` flag.
 
diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py
index fbcdb405a5..5afc98a08a 100644
--- a/pytorch_lightning/plugins/training_type/sharded.py
+++ b/pytorch_lightning/plugins/training_type/sharded.py
@@ -21,6 +21,7 @@ from pytorch_lightning.core.optimizer import is_lightning_optimizer
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if _FAIRSCALE_AVAILABLE:
     from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
@@ -85,6 +86,11 @@ class DDPShardedPlugin(DDPPlugin):
 
     @property
     def lightning_module(self) -> LightningModule:
+        if not _FAIRSCALE_AVAILABLE:  # pragma: no cover
+            raise MisconfigurationException(
+                "`DDPShardedPlugin` requires `fairscale` to be installed."
+                " Install it by running `pip install fairscale`."
+            )
         return unwrap_lightning_module_sharded(self._model)
 
     def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py
index a99d6ea481..9d8b63d5b9 100644
--- a/pytorch_lightning/plugins/training_type/sharded_spawn.py
+++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -21,6 +21,7 @@ from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNative
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if _FAIRSCALE_AVAILABLE:
     from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
@@ -71,6 +72,11 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
 
     @property
     def lightning_module(self) -> LightningModule:
+        if not _FAIRSCALE_AVAILABLE:  # pragma: no cover
+            raise MisconfigurationException(
+                "`DDPSpawnShardedPlugin` requires `fairscale` to be installed."
+                " Install it by running `pip install fairscale`."
+            )
         return unwrap_lightning_module_sharded(self._model)
 
     def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
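
For readers skimming the patch, the sketch below shows the guard pattern being added, in isolation. The availability flag, exception class, and plugin class here are simplified stand-ins rather than the Lightning internals; only the error message mirrors the patch, and the final ``return`` is a placeholder for the real ``unwrap_lightning_module_sharded`` call.

.. code-block:: python

    import importlib.util

    # Stand-in for pytorch_lightning.utilities._FAIRSCALE_AVAILABLE.
    _FAIRSCALE_AVAILABLE = importlib.util.find_spec("fairscale") is not None


    class MisconfigurationException(Exception):
        """Stand-in for pytorch_lightning.utilities.exceptions.MisconfigurationException."""


    class ShardedPluginSketch:
        """Simplified stand-in for DDPShardedPlugin / DDPSpawnShardedPlugin."""

        def __init__(self, model=None):
            self._model = model

        @property
        def lightning_module(self):
            # Fail fast with an actionable message instead of a NameError on the
            # fairscale-only unwrap helper when the dependency is missing.
            if not _FAIRSCALE_AVAILABLE:
                raise MisconfigurationException(
                    "`DDPShardedPlugin` requires `fairscale` to be installed."
                    " Install it by running `pip install fairscale`."
                )
            # The real plugins call unwrap_lightning_module_sharded(self._model) here.
            return self._model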