Add fairscale install msg for Sharded Plugins (#7213)

Kaushik B 2021-04-27 13:52:44 +05:30 committed by GitHub
parent 52a5cee0a7
commit 5cf9afa176
3 changed files with 13 additions and 1 deletion


@@ -675,7 +675,7 @@ To use Sharded Training, you need to first install FairScale using the command b
 .. code-block:: python

     # train using Sharded DDP
-    trainer = Trainer(accelerator='ddp', plugins='ddp_sharded')
+    trainer = Trainer(plugins='ddp_sharded')

 Sharded Training can work across all DDP variants by adding the additional ``--plugins ddp_sharded`` flag.
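For context beyond this diff, selecting the spawn-based sharded variant explicitly would look roughly like the sketch below. The 'ddp_sharded_spawn' alias, the accelerator argument, and the gpus count are assumptions based on the 1.2/1.3-era Trainer API, not part of this change.

    # hypothetical sketch (not part of this diff): Sharded Training with the spawn DDP variant
    from pytorch_lightning import Trainer

    # 'ddp_sharded_spawn' is assumed to map to DDPSpawnShardedPlugin
    trainer = Trainer(accelerator='ddp_spawn', plugins='ddp_sharded_spawn', gpus=2)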


@@ -21,6 +21,7 @@ from pytorch_lightning.core.optimizer import is_lightning_optimizer
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only
+from pytorch_lightning.utilities.exceptions import MisconfigurationException

 if _FAIRSCALE_AVAILABLE:
     from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
@@ -85,6 +86,11 @@ class DDPShardedPlugin(DDPPlugin):
     @property
     def lightning_module(self) -> LightningModule:
+        if not _FAIRSCALE_AVAILABLE:  # pragma: no cover
+            raise MisconfigurationException(
+                "`DDPShardedPlugin` requires `fairscale` to be installed."
+                " Install it by running `pip install fairscale`."
+            )
         return unwrap_lightning_module_sharded(self._model)

     def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
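The effect of the new guard, sketched below under the assumption that fairscale is not installed and that DDPShardedPlugin can be imported from pytorch_lightning.plugins and constructed with its defaults: accessing lightning_module now fails fast with an actionable install message instead of an obscure error from the fairscale-gated imports.

    # hypothetical reproduction in an environment WITHOUT fairscale installed
    from pytorch_lightning.plugins import DDPShardedPlugin
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    plugin = DDPShardedPlugin()
    try:
        plugin.lightning_module  # hits the new _FAIRSCALE_AVAILABLE check
    except MisconfigurationException as err:
        print(err)  # "`DDPShardedPlugin` requires `fairscale` to be installed. ..."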


@@ -21,6 +21,7 @@ from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNative
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
+from pytorch_lightning.utilities.exceptions import MisconfigurationException

 if _FAIRSCALE_AVAILABLE:
     from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
@@ -71,6 +72,11 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
     @property
     def lightning_module(self) -> LightningModule:
+        if not _FAIRSCALE_AVAILABLE:  # pragma: no cover
+            raise MisconfigurationException(
+                "`DDPSpawnShardedPlugin` requires `fairscale` to be installed."
+                " Install it by running `pip install fairscale`."
+            )
         return unwrap_lightning_module_sharded(self._model)

     def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
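The same fail-fast pattern, distilled into a standalone sketch; the names below are illustrative stand-ins rather than Lightning internals, but the availability-flag-plus-exception shape mirrors what both plugins now do.

    import importlib.util

    # stand-in for Lightning's import-time _FAIRSCALE_AVAILABLE flag
    _FAIRSCALE_AVAILABLE = importlib.util.find_spec("fairscale") is not None


    class MisconfigurationException(Exception):
        """Illustrative stand-in for pytorch_lightning.utilities.exceptions.MisconfigurationException."""


    def require_fairscale(plugin_name: str) -> None:
        # raise early with the exact install command rather than failing later on a missing import
        if not _FAIRSCALE_AVAILABLE:
            raise MisconfigurationException(
                f"`{plugin_name}` requires `fairscale` to be installed."
                " Install it by running `pip install fairscale`."
            )


    require_fairscale("DDPShardedPlugin")  # raises only when fairscale is absent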