Add fairscale install msg for Sharded Plugins (#7213)
This commit is contained in:
parent
52a5cee0a7
commit
5cf9afa176
|
@ -675,7 +675,7 @@ To use Sharded Training, you need to first install FairScale using the command b
|
|||
.. code-block:: python
|
||||
|
||||
# train using Sharded DDP
|
||||
trainer = Trainer(accelerator='ddp', plugins='ddp_sharded')
|
||||
trainer = Trainer(plugins='ddp_sharded')
|
||||
|
||||
Sharded Training can work across all DDP variants by adding the additional ``--plugins ddp_sharded`` flag.
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ from pytorch_lightning.core.optimizer import is_lightning_optimizer
|
|||
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
|
||||
from pytorch_lightning.trainer.states import TrainerState
|
||||
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
if _FAIRSCALE_AVAILABLE:
|
||||
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
|
||||
|
@ -85,6 +86,11 @@ class DDPShardedPlugin(DDPPlugin):
|
|||
|
||||
@property
|
||||
def lightning_module(self) -> LightningModule:
|
||||
if not _FAIRSCALE_AVAILABLE: # pragma: no cover
|
||||
raise MisconfigurationException(
|
||||
"`DDPShardedPlugin` requires `fairscale` to be installed."
|
||||
" Install it by running `pip install fairscale`."
|
||||
)
|
||||
return unwrap_lightning_module_sharded(self._model)
|
||||
|
||||
def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
|
||||
|
|
|
@ -21,6 +21,7 @@ from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNative
|
|||
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
|
||||
from pytorch_lightning.trainer.states import TrainerState
|
||||
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
if _FAIRSCALE_AVAILABLE:
|
||||
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
|
||||
|
@ -71,6 +72,11 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
|
|||
|
||||
@property
|
||||
def lightning_module(self) -> LightningModule:
|
||||
if not _FAIRSCALE_AVAILABLE: # pragma: no cover
|
||||
raise MisconfigurationException(
|
||||
"`DDPSpawnShardedPlugin` requires `fairscale` to be installed."
|
||||
" Install it by running `pip install fairscale`."
|
||||
)
|
||||
return unwrap_lightning_module_sharded(self._model)
|
||||
|
||||
def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
|
||||
|
|
Loading…
Reference in New Issue