# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, List, Optional

import torch
import torch.distributed as torch_distrib
from torch import nn
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import LightningModule
from pytorch_lightning import _logger as log
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _FAIRSCALE_PIPE_AVAILABLE:
    import fairscale.nn.model_parallel as mpu
    from fairscale.nn import PipeRPCWrapper
    from fairscale.nn.pipe import balance as pipe_balance
    from fairscale.nn.pipe import rpc as rpc_pipe
    from fairscale.nn.pipe.pipeline import PipelineStyle

class DDPSequentialPlugin(RPCPlugin):
    def __init__(
            self,
            balance: Optional[List[int]] = None,
            microbatches: int = 8,
            checkpoint: str = 'except_last',
            balance_mode: str = "balance_by_size",
            pipelined_backward: Optional[bool] = True,
            **kwargs):
        """
        Provides sequential model parallelism for :class:`nn.Sequential <torch.nn.Sequential>` module.
        If the module requires lots of memory, Pipe can be used to reduce this by leveraging multiple GPUs.

        Example::

            class MyLightningModule(LightningModule):
                def __init__(self):
                    ...
                    self.sequential_module = torch.nn.Sequential(my_layers)

            # Split my module across 4 gpus, one layer each
            model = MyLightningModule()
            plugin = DDPSequentialPlugin(balance=[1, 1, 1, 1])
            trainer = Trainer(accelerator='ddp', gpus=4, plugins=[plugin])
            trainer.fit(model)

        .. _DDPSequentialPlugin: https://arxiv.org/abs/1811.06965

        Pipeline parallelism comes with checkpointing to reduce peak
        memory required to train while minimizing device under-utilization.
        This is turned on by default and can be turned off via the checkpoint argument.

        You should determine the balance when defining the plugin,
        or you can pass an example input array via the LightningModule to infer a balance.
        The module will be partitioned across multiple devices according to the given balance. You may also rely on
        your own heuristics to find your own optimal configuration.

        Args:
            balance: The balance of the model, e.g. [2, 2] (two layers on each GPU).
                If not provided, it is assumed that the user provides an example input array
                through the LightningModule to find a balance on all GPUs.

            microbatches: Splits each batch into smaller micro-batches that are pipelined
                across devices to reduce device under-utilization.

            checkpoint: Enables gradient checkpointing. ['always', 'except_last', 'never']

            balance_mode: Type of balance heuristic to use if the balance is to be inferred.

                - 'balance_by_size': checks memory usage of each layer and determines balance

                - 'balance_by_time': checks time of each layer and determines balance

            pipelined_backward: if True, call torch.autograd.backward once per microbatch on the
                backward pass (instead of once for the whole batch). This works
                around a potential deadlock in pytorch when using tensor parallelism
                at the same time. Defaults to `True` if
                `get_model_parallel_world_size() > 1`
        """
        self._check_pipe_available()
        super().__init__(**kwargs)

        self.balance = balance

        self.microbatches = microbatches
        self.checkpoint = checkpoint
        self.balance_mode = balance_mode
        self.pipelined_backward = pipelined_backward
        self.main_rpc_process = False  # Updated by main process, default for all secondary processes

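    # A minimal sketch of the inferred-balance path described in the docstring
    # above (the layer sizes and `example_input_array` shape are illustrative
    # assumptions, not part of this plugin):
    #
    #     class MyLightningModule(LightningModule):
    #         def __init__(self):
    #             super().__init__()
    #             self.sequential_module = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 2))
    #             self.example_input_array = torch.rand(1, 32)  # used to infer the balance
    #
    #     trainer = Trainer(accelerator='ddp', gpus=2, plugins=[DDPSequentialPlugin()])
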
    def init_ddp_connection(
            self,
            trainer,
            cluster_environment,
            global_rank: int,
            world_size: int,
            is_slurm_managing_tasks: bool = True,
    ) -> None:
        trainer.prepared_for_backwards = False
        self._check_arguments(trainer)
        if self._skip_init_connections(trainer):
            return
        super().init_ddp_connection(
            trainer=trainer,
            cluster_environment=cluster_environment,
            global_rank=global_rank,
            world_size=world_size,
            is_slurm_managing_tasks=is_slurm_managing_tasks
        )
        super().init_rpc_connection(
            global_rank=global_rank,
            world_size=world_size
        )
        model = trainer.get_model()
        self.gpus_per_model = self._infer_check_num_gpus(trainer)
        self.init_model_parallel_groups(trainer)
        self.set_main_rpc_process()

        self._check_sequential_model_exists(model)
        if self.main_rpc_process:
            if self.balance is None:
                self._infer_model_balance(trainer)
            self._assert_valid_model_balance(trainer)

    def on_before_manual_backward(self, model: LightningDistributedDataParallel, output: Any):
        pass

    def _infer_model_balance(self, trainer):
        log.info(f'Inferring model balance using {self.balance_mode} mode')
        model = trainer.get_model()
        if model.example_input_array is None:
            raise MisconfigurationException(
                'Please set example_input_array to your model, so we can infer the right model balance for you')
        balance_func = getattr(pipe_balance, self.balance_mode)
        self.balance = balance_func(self.gpus_per_model, model.sequential_module, model.example_input_array)
        self._sync_balance_to_all_parallel_groups()

        log.info(f'The following model balance {self.balance.tolist()} was inferred using {self.balance_mode} mode')

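    # What the inference above resolves to for `balance_mode="balance_by_size"`,
    # as a hedged sketch (the sample shape and resulting split are made up):
    #
    #     from fairscale.nn.pipe import balance as pipe_balance
    #     sample = model.example_input_array  # e.g. torch.rand(32, 128)
    #     balance = pipe_balance.balance_by_size(2, model.sequential_module, sample)
    #     # balance might come back as [3, 2] for a 5-layer nn.Sequential on 2 GPUs
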
    def _sync_balance_to_all_parallel_groups(self, main_rank=0):
        """
        Ensures that we sync the balance to all main processes, so that the balance is the same per replica.

        Args:
            main_rank: The rank with the balance we'd like to replicate.
        """
        self.balance = torch.tensor(self.balance, dtype=torch.int, device='cuda')
        # Ensure we sync to all processes within the main data parallel group
        # We use the data parallel group as all main processes are found within the same group
        torch_distrib.broadcast(self.balance, src=main_rank, group=mpu.get_data_parallel_group())
        self.balance = self.balance.cpu()

    def _check_sequential_model_exists(self, model):
        if not hasattr(model, "sequential_module") or not isinstance(model.sequential_module, nn.Sequential):
            raise MisconfigurationException(
                'Could not find a PipeLightningModule within the model. '
                'Did you set your sequential model as the `sequential_module` attribute of your model?')

    def _find_and_init_pipe_module(self, model):
        if hasattr(model, "sequential_module") and isinstance(model.sequential_module, LightningPipeModule):
            # model has been wrapped already
            return
        elif hasattr(model, "sequential_module") and isinstance(model.sequential_module, nn.Sequential):
            # try to wrap model for the user
            model.sequential_module = LightningPipeModule(
                model.sequential_module,
                balance=self.balance,
                microbatches=self.microbatches,
                checkpoint=self.checkpoint,
            )
            # Update references for workers to access correct lightning functions when calling RPC
            model.sequential_module.trainer = model.trainer
            model.sequential_module.configure_optimizers = model.configure_optimizers

            # Update references for main process to access correct lightning functions when calling RPC
            model.sequential_module.module.model.trainer = model.trainer
            model.sequential_module.module.model.configure_optimizers = model.configure_optimizers

        else:
            raise MisconfigurationException(
                'Could not find a PipeLightningModule within the model. '
                'Did you set your sequential model as the `sequential_module` attribute of your model?'
            )

    def _assert_valid_model_balance(self, trainer):
        model = trainer.get_model()
        if sum(self.balance) != len(model.sequential_module):
            raise MisconfigurationException(
                f'The provided balance sum: {sum(self.balance)} does not'
                f' match your Sequential length: {len(model.sequential_module)}')

    def _skip_init_connections(self, trainer):
        """
        Skip initialization if torch.distributed is already initialized and we're running tests.

        Returns: Whether to skip initialization
        """
        return torch_distrib.is_initialized() and trainer.testing

    def init_model_parallel_groups(self, trainer):
        num_model_parallel = 1  # TODO currently no support for vertical model parallel
        mpu.initialize_model_parallel(
            model_parallel_size_=num_model_parallel,
            pipeline_length=self.gpus_per_model
        )

    def _infer_check_num_gpus(self, trainer):
        """
        Infer the number of GPUs per model.

        Args:
            trainer: The trainer object.

        Returns: The number of GPUs the model is split across
        """
        if isinstance(self.balance, list):
            if len(self.balance) != (trainer.world_size / trainer.num_nodes):
                raise MisconfigurationException(
                    "Pipe currently only supports splitting the module onto all available GPUs"
                )
            # The user has defined a balance for their model
            return len(self.balance)
        # Assume that the user wants to balance their model on all GPUs
        return trainer.world_size

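    # Worked example of the check above (all numbers are illustrative): with
    # trainer.world_size == 8 and trainer.num_nodes == 2 there are 8 / 2 == 4
    # GPUs per node, so a user-provided balance must have length 4
    # (e.g. [2, 2, 2, 2]); if no balance is given, the module is split across
    # all 8 GPUs.
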
    def on_accelerator_exit_rpc_process(self, trainer) -> None:
        if not trainer.testing:
            torch_distrib.barrier()  # Ensure we await main process initialization

        # Add trainer/configure_optimizers to the pipe model for access in all worker processes
        rpc_pipe.PipeModel.trainer = trainer
        del rpc_pipe.PipeModel.trainer.model.sequential_module
        rpc_pipe.PipeModel.trainer.model.sequential_module = rpc_pipe.PipeModel
        rpc_pipe.PipeModel.configure_optimizers = trainer.model.configure_optimizers
        super().on_accelerator_exit_rpc_process(trainer)

    def set_main_rpc_process(self):
        self.main_rpc_process = torch_distrib.get_rank(group=mpu.get_pipeline_parallel_group()) == 0

    def on_main_rpc_connection(self, trainer) -> None:
        # Create pipe_module
        model = trainer.get_model()
        self._find_and_init_pipe_module(model)
        if not trainer.testing:
            torch_distrib.barrier()  # Ensure we join main process initialization
        model.sequential_module.foreach_worker(register_optimizers, include_self=True)

    def _check_arguments(self, trainer):
        if trainer.amp_backend is not None:
            raise MisconfigurationException(
                'DDPSequentialPlugin is currently not supported with Automatic Mixed Precision')

    def configure_ddp(
            self,
            model: LightningModule, device_ids: List[int]) -> DistributedDataParallel:
        ddp_plugin = RPCPlugin(process_group=mpu.get_data_parallel_group()).configure_ddp(model, device_ids)
        # The plugin handles the backward pass across processes. Currently not supported for DDP + pipe parallel
        ddp_plugin.PREPARE_FOR_BACKWARDS = False
        return ddp_plugin

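    # Illustrative process-group layout (the exact rank assignment is decided
    # by fairscale's mpu.initialize_model_parallel, so treat this as a sketch):
    # with 4 GPUs and balance=[1, 1] there are 2 pipeline stages and 2 replicas,
    # e.g. pipeline groups {0, 1} and {2, 3} with data-parallel groups {0, 2}
    # and {1, 3}; the DDP wrapper above syncs gradients only within the
    # data-parallel groups.
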
    @rank_zero_only
    def rpc_save_model(
            self,
            save_model_fn,
            last_filepath,
            trainer,
            pl_module) -> None:
        model = trainer.get_model()
        if not hasattr(model.sequential_module, "foreach_worker"):
            return
        current_layers = pl_module.sequential_module
        model.sequential_module.foreach_worker(
            save_layers_on_all_rank_zero_workers,
            {"gpus_per_model": self.gpus_per_model},
            include_self=True
        )
        pl_module.sequential_module = load_sequential_from_saved_layers(self.gpus_per_model)
        save_model_fn(last_filepath, trainer, pl_module)
        pl_module.sequential_module = current_layers

    def worker_optimizer_step(
            self,
            model: LightningModule,
            opt_idx: int,
            *args,
            **kwargs) -> None:
        model.sequential_module.foreach_worker(
            run_optimizer,
            {"opt_idx": opt_idx, "args": args, "kwargs": kwargs},
            include_self=False
        )

    def distributed_sampler_kwargs(self, distributed_sampler_kwargs):
        return dict(
            num_replicas=mpu.get_data_parallel_world_size(),
            rank=mpu.get_data_parallel_rank(),
        )

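    # A hedged sketch of how these kwargs feed torch's own DistributedSampler
    # (`plugin`, `train_dataset` and `default_kwargs` are assumptions for
    # illustration only):
    #
    #     from torch.utils.data import DistributedSampler
    #     kwargs = plugin.distributed_sampler_kwargs(default_kwargs)  # input is ignored here
    #     sampler = DistributedSampler(train_dataset, **kwargs)
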
    @property
    def data_parallel_group(self):
        return mpu.get_data_parallel_group()

    @property
    def is_main_rpc_process(self) -> bool:
        return self.main_rpc_process

    @property
    def return_after_exit_rpc_process(self) -> bool:
        return True

    def barrier(self, name: Optional[str] = None) -> None:
        if torch_distrib.is_initialized() and self.is_main_rpc_process:
            torch_distrib.barrier(group=self.data_parallel_group)

    def _check_pipe_available(self):
        if not _FAIRSCALE_PIPE_AVAILABLE:
            raise MisconfigurationException(
                'DDPSequentialPlugin requires FairScale and is currently only supported on PyTorch 1.6.'
            )


class LightningPipeModule(nn.Module):
    """
    This class wraps FairScale's Pipe and PipeRPCWrapper classes.
    """

    def __init__(
            self,
            module: nn.Sequential,
            balance: List[int],
            microbatches: int = 8,
            checkpoint='never'):
        super().__init__()
        self.module = module
        self.balance = balance
        self.microbatches = microbatches
        self.checkpoint = checkpoint
        self._init_pipe()

    def _init_pipe(self):
        device = torch.device("cuda", torch_distrib.get_rank())

        self.module = PipeRPCWrapper(
            module=self.module,
            balance=self.balance,
            chunks=self.microbatches,
            style=PipelineStyle.MultiProcess,
            input_device=device,
            worker_map=self.get_worker_map(),
            checkpoint=self.checkpoint,
        )

    def foreach_worker(self, *args, **kwargs):
        self.module.foreach_worker(*args, **kwargs)

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def get_worker_map(self):
        # TODO: is this correct with multiple nodes? We also assume "worker" is the same as defined in the RPCPlugin
        return {rank: f"worker{rank}" for rank in range(torch_distrib.get_world_size())}


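# For reference, `LightningPipeModule.get_worker_map` above returns a mapping
# like {0: "worker0", 1: "worker1", 2: "worker2", 3: "worker3"} for a world size
# of 4, assuming the `worker{rank}` naming matches what the RPCPlugin registered
# (see the TODO above).
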
def register_optimizers(ctx, model):
    optimizers, lr_schedulers, optimizer_frequencies = model.trainer.init_optimizers(model)
    model.trainer.optimizers = optimizers
    model.trainer.lr_schedulers = lr_schedulers
    model.trainer.optimizer_frequencies = optimizer_frequencies
    model.trainer.convert_to_lightning_optimizers()


def run_optimizer(ctx, model):
    trainer = model.trainer
    opt_idx = ctx["opt_idx"]
    optimizer = trainer.optimizers[opt_idx]
    optimizer.step(*ctx["args"], **ctx["kwargs"])

def save_layers_on_all_rank_zero_workers(ctx, model):
    gpus_per_model = ctx["gpus_per_model"]
    rank = torch_distrib.get_rank()
    if rank in range(gpus_per_model):
        seq = list(model.children())[0]
        torch.save(seq, f"seq_{rank}.pt")


def load_sequential_from_saved_layers(gpus_per_model):
    partial_seqs = [torch.load(f"seq_{rank}.pt", map_location='cpu') for rank in range(gpus_per_model)]
    seq = nn.Sequential()
    for p_seq in partial_seqs:
        for name, child in p_seq.named_children():
            seq.add_module(name, child)
    # delete tmp files
    for rank in range(gpus_per_model):
        os.remove(f"seq_{rank}.pt")
    return seq
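

# A hedged sketch of the checkpoint round trip driven by `rpc_save_model` above
# (the two-GPU setup is an assumption; the seq_{rank}.pt names follow the helpers):
#
#     # 1. every worker whose rank < gpus_per_model dumps its partition:
#     #      save_layers_on_all_rank_zero_workers -> seq_0.pt, seq_1.pt
#     # 2. the main process stitches the partitions back into one nn.Sequential
#     #    so the saved checkpoint contains the full model:
#     full_module = load_sequential_from_saved_layers(gpus_per_model=2)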