Add base code

SeanNaren 2020-11-19 10:21:34 +00:00
parent b506a7e46a
commit 2e8585f46a
3 changed files with 168 additions and 0 deletions


@@ -0,0 +1,47 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Union

from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
from fairscale.optim import OSS
from torch import nn


class LightningShardedDataParallel(ShardedDataParallel):

    def __init__(
        self,
        base_model: nn.Module,
        sharded_optimizer: Union[OSS, List[OSS]],
        process_group: Any = None,
        broadcast_buffers: bool = True
    ):
        super().__init__(
            base_model=base_model,
            sharded_optimizer=sharded_optimizer,
            process_group=process_group,
            broadcast_buffers=broadcast_buffers
        )
        # Expose the wrapped LightningModule as ``.module``, mirroring torch DistributedDataParallel.
        self.module = base_model

    def forward(self, *inputs, **kwargs):
        if self.enable_broadcast_buffers:
            self.sync_buffers()

        # Route the call to the matching LightningModule hook rather than its ``forward``.
        if self.base_model.training:
            outputs = self.base_model.training_step(*inputs, **kwargs)
        elif self.base_model.testing:
            outputs = self.base_model.test_step(*inputs, **kwargs)
        else:
            outputs = self.base_model.validation_step(*inputs, **kwargs)
        return outputs
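
For context (not part of the commit): a rough, hand-driven sketch of the wrapper above. It assumes the fairscale version this commit targets (a ShardedDataParallel that accepts base_model and exposes it on the instance), a single-process "gloo" group so it can run on one CPU, and a hypothetical ToyModel; in practice Lightning builds and calls this wrapper from inside the trainer.

import os

import torch
import torch.distributed as dist
from fairscale.optim import OSS

from pytorch_lightning import LightningModule
from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel


class ToyModel(LightningModule):  # hypothetical module, purely for illustration
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()


# ShardedDataParallel and OSS both expect torch.distributed to be initialised;
# a single-process "gloo" group keeps the sketch runnable on one CPU.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = ToyModel()
oss_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.1)
wrapped = LightningShardedDataParallel(model, sharded_optimizer=oss_optimizer)

# model.training is True, so forward() dispatches to training_step.
loss = wrapped(torch.randn(4, 32), 0)
loss.backward()
oss_optimizer.step()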


@@ -0,0 +1,31 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import cast

from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler

from pytorch_lightning.plugins.native_amp import NativeAMPPlugin


class ShardedNativeAMPPlugin(NativeAMPPlugin):

    @property
    def scaler(self):
        # Override the native AMP scaler with fairscale's sharded-aware implementation.
        return ShardedGradScaler()

    def clip_gradients(self, grad_clip_val, model, optimizer):
        max_norm = grad_clip_val
        norm_type = float(2.0)
        # Clipping must go through the OSS wrapper, which accounts for the
        # optimizer state being sharded across ranks.
        optimizer = cast(OSS, optimizer)
        optimizer.clip_grad_norm(max_norm, norm_type=norm_type)
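
A minimal sketch (not part of the commit) of the update path this plugin encodes: ShardedGradScaler for loss scaling plus OSS.clip_grad_norm for clipping. The single-process "gloo" group, the toy linear model, and the max_norm of 1.0 are assumptions made for illustration.

import os

import torch
import torch.distributed as dist
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler

# OSS shards optimizer state across ranks, so it needs an initialised group.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(32, 2)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.1)
scaler = ShardedGradScaler()

loss = model(torch.randn(4, 32)).sum()
scaler.scale(loss).backward()

# The clipping step of clip_gradients() boils down to: unscale, then clip
# through the OSS wrapper, since each rank only owns a shard of the parameters.
scaler.unscale_(optimizer)
optimizer.clip_grad_norm(1.0, norm_type=2.0)
scaler.step(optimizer)
scaler.update()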


@@ -0,0 +1,90 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Any, Optional

from fairscale.optim import OSS

from pytorch_lightning import LightningModule
from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.utilities import rank_zero_only


class DDPShardedPlugin(DDPPlugin):

    def configure_ddp(
        self, model: LightningModule, device_ids: List[int]
    ):
        self._wrap_optimizers(model)
        if model.trainer.testing:  # Revert to standard DDP during testing
            super().configure_ddp(
                model=model,
                device_ids=device_ids
            )
        else:
            model = LightningShardedDataParallel(model, sharded_optimizer=model.trainer.optimizers)
        return model

    def optimizer_state(self, optimizer: OSS) -> Optional[dict]:
        # Gather the sharded optimizer state onto rank 0 before returning it.
        optimizer.consolidate_state_dict()
        return self._optim_state_dict(optimizer)

    def on_before_forward(self, args: Any, model: LightningModule):
        batch = args[0]
        batch = model.transfer_batch_to_device(batch, model.trainer.root_gpu)
        args[0] = batch
        return args

    @rank_zero_only
    def _optim_state_dict(self, optimizer):
        """
        Ensure we only return the state dict from the optimizer on rank 0.
        Other ranks do not have the complete optimizer state.

        Args:
            optimizer: OSS Optimizer

        Returns:
            State dict if rank 0 else None.
        """
        return optimizer.state_dict()

    def _wrap_optimizers(self, model):
        trainer = model.trainer
        if trainer.testing is True:
            return
        self._reinit_with_fairscale_oss(trainer)

    def _reinit_with_fairscale_oss(self, trainer):
        """
        Re-initialise optimizers to use the OSS wrapper. We need to re-initialise because
        the parameters are sharded across distributed processes, each optimizing a partition.

        Args:
            trainer: trainer object to reinit optimizers.
        """
        optimizers = trainer.optimizers
        lr_schedulers = trainer.lr_schedulers
        for x, optimizer in enumerate(optimizers):
            if not isinstance(optimizer, OSS):
                optim_class = type(optimizer)
                zero_optimizer = OSS(
                    params=optimizer.param_groups,
                    optim=optim_class,
                    **optimizer.defaults
                )
                optimizers[x] = zero_optimizer
                # Re-point any scheduler that referenced the original optimizer at the OSS wrapper.
                for scheduler in lr_schedulers:
                    scheduler = scheduler['scheduler']
                    if scheduler.optimizer == optimizer:
                        scheduler.optimizer = zero_optimizer
                del optimizer
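
To make the re-wrapping concrete, a small standalone sketch (not part of the commit) of what _reinit_with_fairscale_oss does for one optimizer/scheduler pair, followed by the consolidate-then-dump pattern from optimizer_state(). The single-process "gloo" group and the Adam/StepLR/Linear trio are assumptions made for illustration.

import os

import torch
import torch.distributed as dist
from fairscale.optim import OSS

# A single-process "gloo" group stands in for the real multi-process job.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(32, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# The same conversion _reinit_with_fairscale_oss applies to trainer.optimizers:
# reuse the existing param groups and defaults, but let OSS own (a shard of) the state.
zero_optimizer = OSS(
    params=optimizer.param_groups,
    optim=type(optimizer),
    **optimizer.defaults
)
scheduler.optimizer = zero_optimizer  # re-point the scheduler, as the plugin does

# As in optimizer_state(): gather the sharded state onto rank 0 before saving.
zero_optimizer.consolidate_state_dict()
state_dict = zero_optimizer.state_dict()  # complete only on rank 0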