# lightning/pytorch_lightning/plugins/sharded_plugin.py
# (source listing: 103 lines, 3.8 KiB, Python)

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Union

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import is_lightning_optimizer
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, AMPType, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _FAIRSCALE_AVAILABLE:
    from fairscale.optim import OSS

    from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel
class DDPShardedPlugin(DDPPlugin):
    """DDP plugin that trains with FairScale sharding.

    The trainer's optimizers are rebuilt as FairScale ``OSS`` wrappers
    (see ``_reinit_with_fairscale_oss``), and the model is wrapped in
    ``LightningShardedDataParallel`` (see ``configure_ddp``). Apex AMP is
    rejected; native AMP pulls in ``ShardedNativeAMPPlugin``
    (see ``required_plugins``).
    """

    def __init__(self, **kwargs):
        """Create the plugin.

        Args:
            **kwargs: Forwarded unchanged to ``DDPPlugin.__init__``.

        Raises:
            MisconfigurationException: If FairScale is not installed. The
                check runs before the parent initializer is called.
        """
        self._check_fairscale()
        super().__init__(**kwargs)

    def configure_ddp(
        self, model: LightningModule, device_ids: List[int]
    ):
        """Prepare sharded optimizers and wrap ``model`` for sharded DDP.

        Args:
            model: The LightningModule to wrap. Its ``trainer.optimizers``
                are converted to ``OSS`` first, except while testing.
            device_ids: Accepted for interface compatibility with
                ``DDPPlugin``; not used by this implementation.

        Returns:
            A ``LightningShardedDataParallel`` wrapping ``model``. It is
            given the trainer's (possibly re-created) optimizer list.
        """
        self._wrap_optimizers(model)
        return LightningShardedDataParallel(model, sharded_optimizer=model.trainer.optimizers)

    def optimizer_state(self, optimizer: 'OSS') -> Optional[dict]:
        """Return the consolidated state dict of a sharded optimizer.

        ``consolidate_state_dict`` is called on every rank that reaches this
        method. Presumably it is a collective that all ranks must join;
        TODO confirm against the FairScale version in use. The state dict
        itself is read through the ``rank_zero_only`` helper
        ``_optim_state_dict``, hence the ``Optional`` return.
        """
        optimizer.consolidate_state_dict()
        return self._optim_state_dict(optimizer)

    def on_before_forward(self, model: LightningModule, *args):
        """Move the positional forward inputs to the trainer's root GPU.

        Returns:
            The result of ``model.transfer_batch_to_device`` applied to the
            ``args`` tuple and ``model.trainer.root_gpu``.
        """
        return model.transfer_batch_to_device(args, model.trainer.root_gpu)

    def _check_fairscale(self):
        """Raise ``MisconfigurationException`` if FairScale is not importable."""
        if not _FAIRSCALE_AVAILABLE:
            raise MisconfigurationException(
                'Sharded DDP Plugin requires Fairscale to be installed.'
            )

    @rank_zero_only
    def _optim_state_dict(self, optimizer):
        """Return ``optimizer.state_dict()``; runs only on rank zero."""
        return optimizer.state_dict()

    def _wrap_optimizers(self, model):
        """Convert the trainer's optimizers to ``OSS`` unless testing.

        The conversion is skipped while ``trainer.testing`` is set.
        """
        trainer = model.trainer
        if trainer.testing:
            return
        self._reinit_with_fairscale_oss(trainer)

    def _reinit_with_fairscale_oss(self, trainer):
        """Replace every non-``OSS`` optimizer in ``trainer.optimizers``.

        The list is modified in place. Each optimizer (unwrapped from its
        LightningOptimizer wrapper, if any) that is not already an ``OSS``
        instance is rebuilt as an ``OSS`` with the same parameter groups,
        the same optimizer class and the same default hyper-parameters.
        Afterwards ``trainer.convert_to_lightning_optimizers()`` is called,
        presumably to re-wrap the new objects; TODO confirm.
        """
        optimizers = trainer.optimizers
        for idx, optimizer in enumerate(optimizers):
            if is_lightning_optimizer(optimizer):
                # Inspect the wrapped torch optimizer, not the Lightning wrapper.
                optimizer = optimizer._optimizer
            if isinstance(optimizer, OSS):
                continue
            # Index assignment only; the list length does not change, so
            # mutating it during enumeration is safe.
            optimizers[idx] = OSS(
                params=optimizer.param_groups,
                optim=type(optimizer),
                **optimizer.defaults
            )
        trainer.convert_to_lightning_optimizers()

    def get_model_from_plugin(
        self,
        model: Union['LightningShardedDataParallel', LightningModule]
    ) -> LightningModule:
        """Return the underlying LightningModule.

        The sharded wrapper is unwrapped if present. Any other ``model`` is
        returned as is.
        """
        if isinstance(model, LightningShardedDataParallel):
            return model.module
        return model

    def required_plugins(self, amp_backend: AMPType, trainer) -> list:
        """Return the extra plugins needed for ``amp_backend``.

        Returns:
            ``[ShardedNativeAMPPlugin(trainer=trainer)]`` for native AMP, or
            ``[]`` for any other non-Apex backend.

        Raises:
            MisconfigurationException: If ``amp_backend`` is Apex, which
                this plugin does not support.
        """
        if amp_backend == AMPType.APEX:
            raise MisconfigurationException(
                'Sharded Plugin is not supported with Apex AMP, please use native AMP for 16-bit precision.'
            )
        if amp_backend == AMPType.NATIVE:
            return [ShardedNativeAMPPlugin(trainer=trainer)]
        return []

    def on_before_manual_backward(self, model: 'LightningShardedDataParallel', output: Any):
        """Do nothing before a manual backward pass.

        NOTE(review): presumably this overrides a ``DDPPlugin`` hook to
        disable the parent's behavior for the sharded wrapper; TODO confirm.
        """
        pass

    def on_after_manual_backward(self, model: 'LightningShardedDataParallel'):
        """Do nothing after a manual backward pass.

        NOTE(review): presumably this overrides a ``DDPPlugin`` hook to
        disable the parent's behavior for the sharded wrapper; TODO confirm.
        """
        pass