[feat] pp 2/n (#5026)

* Added changes for RPC plugin

* Add missing kwargs

* Fix code format

* Loading refactors by introducing is_distributed var, fix optimizer step flow

* Add rpc guard

* Added docstrings and typing

* resolve comments

* Add additional rpc hook, refactor name of exit process hook for clarity

* remove annotation

* Modify behaviour to allow optional return, add test for rpc plugin

* resolve tests

* rename is_ddp_based

* update

* update for windows

* update

* resolve test

* code smell

* Added sequential plugin

* resolve bug

* update

* cleanup

* add Exception

* resolve docs

* Remove ddp support

* Revert distributed -> ddp

* Update pl_examples/basic_examples/conv_sequential_example.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pl_examples/basic_examples/conv_sequential_example.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/plugins/ddp_sequential_plugin.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Address code review points

* Update pytorch_lightning/plugins/ddp_sequential_plugin.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/plugins/ddp_sequential_plugin.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Add missing return

* Fix formatting, add datamodule args

* add small comment

* resolve comments

* resolve comments

* update source for fairscale

* update extras

* remove staticmethod

* resolve flake8

* Skip tests that are failing due to bug upstream with multiple optimizers and shard

* update

* update on comments

* clean test

* latest comments

* remove old comments

* add todo

* Update version

* update

* resolve bugs

* resolve bugs

* update test

* remove hanging test

* Update pytorch_lightning/plugins/ddp_sequential_plugin.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* resolve on comments

* Update pytorch_lightning/plugins/ddp_sequential_plugin.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* resolve on comments

* Update pytorch_lightning/plugins/ddp_sequential_plugin.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Update pytorch_lightning/plugins/ddp_sequential_plugin.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Update pytorch_lightning/plugins/ddp_sequential_plugin.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Update pytorch_lightning/plugins/ddp_sequential_plugin.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* remove ImportError

Co-authored-by: SeanNaren <sean@grid.ai>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
chaton 2020-12-09 12:56:51 +00:00 committed by GitHub
parent 7d9784e951
commit ef8ef12fd0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 881 additions and 27 deletions

View File

@ -32,5 +32,6 @@ repos:
types: [python] types: [python]
- repo: https://github.com/pre-commit/mirrors-mypy - repo: https://github.com/pre-commit/mirrors-mypy
rev: master
hooks: hooks:
- id: mypy - id: mypy

View File

@ -131,6 +131,7 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):
) )
@pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows", @pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows") reason="Distributed training is not supported on Windows")
@ -148,6 +149,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
) )
@pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows", @pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows") reason="Distributed training is not supported on Windows")
@ -189,7 +191,7 @@ class SeedTrainLoaderManualModel(SeedTrainLoaderModel):
# ensure we forward the correct params to the optimizer # ensure we forward the correct params to the optimizer
# without retain_graph we can't do multiple backward passes # without retain_graph we can't do multiple backward passes
self.manual_backward(loss_2, opt_b, retain_graph=True) self.manual_backward(loss_2, opt_b)
# todo: understand why synchronization breaks there. # todo: understand why synchronization breaks there.
# self.manual_backward(loss_2, opt_a, retain_graph=True) # self.manual_backward(loss_2, opt_a, retain_graph=True)
opt_b.step() opt_b.step()

View File

@ -0,0 +1,216 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example script of running the experimental DDP Sequential Plugin.
This script splits a convolutional model onto multiple GPUs, whilst using the internal built in balancer
to balance across your GPUs.
To run:
python conv_model_sequential_example.py --accelerator ddp --gpus 4 --max_epochs 1 --batch_size 256 --use_ddp_sequential
"""
import math
from argparse import ArgumentParser
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.plugins.ddp_sequential_plugin import DDPSequentialPlugin
from pytorch_lightning.utilities import BOLTS_AVAILABLE, FAIRSCALE_PIPE_AVAILABLE
if BOLTS_AVAILABLE:
import pl_bolts
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
#####################
# Modules #
#####################
class Flatten(nn.Module):
def forward(self, x):
return x.view(x.size(0), -1)
###############################
# LightningModule #
###############################
class LitResnet(pl.LightningModule):
def __init__(self, lr=0.05, batch_size=32, manual_optimization=False):
super().__init__()
self.save_hyperparameters()
self.sequential_module = nn.Sequential(
# Conv Layer block 1
nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=False),
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
nn.ReLU(inplace=False),
nn.MaxPool2d(kernel_size=2, stride=2),
# Conv Layer block 2
nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=False),
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
nn.ReLU(inplace=False),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Dropout2d(p=0.05),
# Conv Layer block 3
nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=False),
nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
nn.ReLU(inplace=False),
nn.MaxPool2d(kernel_size=2, stride=2),
Flatten(),
nn.Dropout(p=0.1),
nn.Linear(4096, 1024),
nn.ReLU(inplace=False),
nn.Linear(1024, 512),
nn.ReLU(inplace=False),
nn.Dropout(p=0.1),
nn.Linear(512, 10)
)
self._example_input_array = torch.randn((1, 3, 32, 32))
self._manual_optimization = manual_optimization
if self._manual_optimization:
self.training_step = self.training_step_manual
def forward(self, x):
out = self.sequential_module(x)
return F.log_softmax(out, dim=-1)
def training_step_manual(self, batch, batch_idx):
opt = self.optimizers()
def closure():
x, y = batch
logits = self.forward(x)
loss = F.nll_loss(logits, y)
self.manual_backward(loss, opt)
self.log('train_loss', loss, prog_bar=True)
opt.step(closure=closure)
def training_step(self, batch, batch_idx):
x, y = batch
logits = self.forward(x)
loss = F.nll_loss(logits, y)
self.log('Training Loss', loss)
return loss
def _evaluate(self, batch, batch_idx, stage=None):
x, y = batch
out = self.forward(x)
logits = F.log_softmax(out, dim=-1)
loss = F.nll_loss(logits, y)
preds = torch.argmax(logits, dim=-1)
acc = accuracy(preds, y)
if stage:
self.log(f'{stage}_loss', loss, prog_bar=True)
self.log(f'{stage}_acc', acc, prog_bar=True)
return loss, acc
def validation_step(self, batch, batch_idx):
return self._evaluate(batch, batch_idx, 'val')[0]
def test_step(self, batch, batch_idx):
loss, acc = self._evaluate(batch, batch_idx, 'test')
self.log_dict({'test_loss': loss, 'test_acc': acc})
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)
return {
'optimizer': optimizer,
'lr_scheduler': {
'scheduler': torch.optim.lr_scheduler.OneCycleLR(
optimizer,
0.1,
epochs=self.trainer.max_epochs,
steps_per_epoch=math.ceil(45000 / self.hparams.batch_size)),
'interval': 'step',
}
}
@property
def automatic_optimization(self) -> bool:
return not self._manual_optimization
#################################
# Instantiate Data Module #
#################################
def instantiate_datamodule(args):
train_transforms = torchvision.transforms.Compose([
torchvision.transforms.RandomCrop(32, padding=4),
torchvision.transforms.RandomHorizontalFlip(),
torchvision.transforms.ToTensor(),
cifar10_normalization(),
])
test_transforms = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
cifar10_normalization(),
])
cifar10_dm = pl_bolts.datamodules.CIFAR10DataModule(
batch_size=args.batch_size,
train_transforms=train_transforms,
test_transforms=test_transforms,
val_transforms=test_transforms,
)
return cifar10_dm
if __name__ == "__main__":
parser = ArgumentParser(description="Pipe Example")
parser.add_argument("--use_ddp_sequential", action="store_true")
parser = Trainer.add_argparse_args(parser)
parser = pl_bolts.datamodules.CIFAR10DataModule.add_argparse_args(parser)
args = parser.parse_args()
assert BOLTS_AVAILABLE, "Bolts is required for this example, install it via pip install pytorch-lightning-bolts"
assert FAIRSCALE_PIPE_AVAILABLE, "FairScale and PyTorch 1.6 is required for this example."
cifar10_dm = instantiate_datamodule(args)
plugins = None
if args.use_ddp_sequential:
plugins = DDPSequentialPlugin()
model = LitResnet(batch_size=args.batch_size, manual_optimization=not args.automatic_optimization)
trainer = pl.Trainer.from_argparse_args(args, plugins=[plugins] if plugins else None)
trainer.fit(model, cifar10_dm)
trainer.test(model, datamodule=cifar10_dm)
if trainer.accelerator_backend.rpc_enabled:
# Called at the end of trainer to ensure all processes are killed
trainer.accelerator_backend.ddp_plugin.exit_rpc_process()

View File

@ -155,6 +155,7 @@ class LightningDistributedDataParallel(DistributedDataParallel):
""" """
Override the forward call in lightning so it goes to training and validation step respectively Override the forward call in lightning so it goes to training and validation step respectively
""" """
PREPARE_FOR_BACKWARDS = True
def parallel_apply(self, replicas, inputs, kwargs): def parallel_apply(self, replicas, inputs, kwargs):
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
@ -165,6 +166,7 @@ class LightningDistributedDataParallel(DistributedDataParallel):
fx_called: str = '' fx_called: str = ''
if self.device_ids: if self.device_ids:
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1: if len(self.device_ids) == 1:
# -------------- # --------------
@ -195,7 +197,7 @@ class LightningDistributedDataParallel(DistributedDataParallel):
else: else:
output = self.module.validation_step(*inputs, **kwargs) output = self.module.validation_step(*inputs, **kwargs)
if not self._reducer_prepared_for_backwards: if not self._reducer_prepared_for_backwards and self.PREPARE_FOR_BACKWARDS:
self.reducer_prepare_for_backwards(output) self.reducer_prepare_for_backwards(output)
if output is None: if output is None:

View File

@ -0,0 +1,409 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import os
from typing import Any, List, Optional
import torch
import torch.distributed as torch_distrib
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from pytorch_lightning import LightningModule
from pytorch_lightning import _logger as log
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import FAIRSCALE_PIPE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
if FAIRSCALE_PIPE_AVAILABLE:
import fairscale.nn.model_parallel as mpu
from fairscale.nn import PipeRPCWrapper
from fairscale.nn.pipe import balance as pipe_balance
from fairscale.nn.pipe import rpc as rpc_pipe
from fairscale.nn.pipe.pipeline import PipelineStyle
class DDPSequentialPlugin(RPCPlugin):
def __init__(
self,
balance: Optional[List[int]] = None,
microbatches: int = 8,
checkpoint: str = 'except_last',
balance_mode: str = "balance_by_size",
pipelined_backward: Optional[bool] = True,
**kwargs):
"""
Provides sequential model parallelism for :class:`nn.Sequential <torch.nn.Sequential>` module.
If the module requires lots of memory, Pipe can be used to reduce this by leveraging multiple GPUs.
Example::
class MyLightningModule:
def __init__(self):
...
model.sequential_module = torch.nn.Sequential(my_layers)
# Split my module across 4 gpus, one layer each
model = MyLightningModule()
plugin = DDPSequentialPlugin(balance=[1, 1, 1, 1])
trainer = Trainer(accelerator='ddp', gpus=4, plugins=[plugin])
trainer.fit(model)
.. _DDPSequentialPlugin: https://arxiv.org/abs/1811.06965
Pipeline parallelism comes with with checkpointing to reduce peak
memory required to train while minimizing device under-utilization.
This is turned on by default and can be turned off via the checkpoint argument.
You should determine the balance when defining the plugin,
or you can pass an example input array via the LightningModule to infer a balance.
The module will be partitioned into multiple devices according to the given balance. You may also rely on
your own heuristics to find your own optimal configuration.
Args:
balance: The balance of the model, i.e [2, 2] (two layers on each GPU).
If not provided assumes user provides an input example array to find a balance on all GPUs.
microbatches: Allows for parallelization to reduce device utilization
by splitting the batch into further smaller batches.
checkpoint: Enables gradient checkpointing. ['always', 'except_last', 'never']
balance_mode: Type of balance heuristic to use if balance to be inferred.
- 'balance_by_size': checks memory usage of each layer and determines balance
- 'balance_by_time': checks time of each layer and determines balance
pipelined_backward: if True, call torch.autograd.backward once per microbatch on the
backward pass (instead of once for the whole batch). This works
around a potential deadlock in pytorch when using tensor parallelism
at the same time. Defaults to `True` if
`get_model_parallel_world_size() > 1`
"""
self._check_pipe_available()
super().__init__(**kwargs)
self.balance = balance
self.microbatches = microbatches
self.checkpoint = checkpoint
self.balance_mode = balance_mode
self.pipelined_backward = pipelined_backward
self.main_rpc_process = False # Updated by main process, default for all secondary processes
def init_ddp_connection(
self,
trainer,
cluster_environment,
global_rank: int,
world_size: int,
is_slurm_managing_tasks: bool = True,
) -> None:
trainer.prepared_for_backwards = False
self._check_arguments(trainer)
if self._skip_init_connections(trainer):
return
super().init_ddp_connection(
trainer=trainer,
cluster_environment=cluster_environment,
global_rank=global_rank,
world_size=world_size,
is_slurm_managing_tasks=is_slurm_managing_tasks
)
super().init_rpc_connection(
global_rank=global_rank,
world_size=world_size
)
model = trainer.get_model()
self.gpus_per_model = self._infer_check_num_gpus(trainer)
self.init_model_parallel_groups(trainer)
self.set_main_rpc_process()
self._check_sequential_model_exists(model)
if self.main_rpc_process:
if self.balance is None:
self._infer_model_balance(trainer)
self._assert_valid_model_balance(trainer)
def on_before_manual_backward(self, model: LightningDistributedDataParallel, output: Any):
pass
def _infer_model_balance(self, trainer):
log.info(f'Inferring model balance using {self.balance_mode} mode')
model = trainer.get_model()
if model.example_input_array is None:
raise MisconfigurationException(
'Please set example_input_array to your model, so we can infer the right model balance for you')
balance_func = getattr(pipe_balance, self.balance_mode)
self.balance = balance_func(self.gpus_per_model, model.sequential_module, model.example_input_array)
self._sync_balance_to_all_parallel_groups()
log.info(f'The following model balance {self.balance.tolist()} was inferred using {self.balance_mode} mode')
def _sync_balance_to_all_parallel_groups(self, main_rank=0):
"""
Ensures that we sync the balance to all main processes, so that the balance is the same per replica.
Args:
main_rank: The rank with the balance we'd like to replicate.
"""
self.balance = torch.tensor(self.balance, dtype=torch.int, device='cuda')
# Ensure we sync to all processes within the main data parallel group
# We use the data parallel group as all main processes are found within the same group
torch_distrib.broadcast(self.balance, src=main_rank, group=mpu.get_data_parallel_group())
self.balance = self.balance.cpu()
def _check_sequential_model_exists(self, model):
if not hasattr(model, "sequential_module") or not isinstance(model.sequential_module, nn.Sequential):
raise MisconfigurationException(
'Could not find a PipeLightningModule within the model. '
'Did you set your sequential model as the `sequential_module` attribute of your model?')
def _find_and_init_pipe_module(self, model):
if hasattr(model, "sequential_module") and isinstance(model.sequential_module, LightningPipeModule):
# model has been wrapped already
return
elif hasattr(model, "sequential_module") and isinstance(model.sequential_module, nn.Sequential):
# try to wrap model for the user
model.sequential_module = LightningPipeModule(
model.sequential_module,
balance=self.balance,
microbatches=self.microbatches,
checkpoint=self.checkpoint,
)
# Update references for workers to access correct lightning functions when calling RPC
model.sequential_module.trainer = model.trainer
model.sequential_module.configure_optimizers = model.configure_optimizers
# Update references for main process to access correct lightning functions when calling RPC
model.sequential_module.module.model.trainer = model.trainer
model.sequential_module.module.model.configure_optimizers = model.configure_optimizers
else:
raise MisconfigurationException(
'Could not find a PipeLightningModule within the model. '
'Did you defined set your sequential model as an `sequential_module` attribute of your model ?'
)
def _assert_valid_model_balance(self, trainer):
model = trainer.get_model()
if sum(self.balance) != len(model.sequential_module):
raise MisconfigurationException(
f'The provided balance sum: {sum(self.balance)} does not'
f' match your Sequential length: {len(model.sequential_module)}')
def _skip_init_connections(self, trainer):
"""
Skip initialization if torch is already initialized and we're in testing.
Returns: Whether to skip initialization
"""
return torch_distrib.is_initialized() and trainer.testing
def init_model_parallel_groups(self, trainer):
num_model_parallel = 1 # TODO currently no support for vertical model parallel
mpu.initialize_model_parallel(
model_parallel_size_=num_model_parallel,
pipeline_length=self.gpus_per_model
)
def _infer_check_num_gpus(self, trainer):
"""
Infer the number of GPUs per model.
Args:
trainer: The trainer object.
Returns: The appropriate balance for the model
"""
if isinstance(self.balance, list):
if len(self.balance) != trainer.world_size:
raise MisconfigurationException(
"Pipe currently only supports splitting the module onto all available GPUs"
)
# User has defined a balance for his model
return len(self.balance)
# Assume that the user wants to balance his model on all GPUs
return trainer.world_size
def on_accelerator_exit_rpc_process(self, trainer) -> None:
if not trainer.testing:
torch_distrib.barrier() # Ensure we await main process initialization
# Add trainer/configure_optimizers to the pipe model for access in all worker processes
rpc_pipe.PipeModel.trainer = trainer
del rpc_pipe.PipeModel.trainer.model.sequential_module
rpc_pipe.PipeModel.trainer.model.sequential_module = rpc_pipe.PipeModel
rpc_pipe.PipeModel.configure_optimizers = trainer.model.configure_optimizers
super().on_accelerator_exit_rpc_process(trainer)
def set_main_rpc_process(self):
self.main_rpc_process = torch_distrib.get_rank(group=mpu.get_pipeline_parallel_group()) == 0
def on_main_rpc_connection(self, trainer) -> None:
# Create pipe_module
model = trainer.get_model()
self._find_and_init_pipe_module(model)
if not trainer.testing:
torch_distrib.barrier() # Ensure we join main process initialization
model.sequential_module.foreach_worker(register_optimizers, include_self=True)
def _check_arguments(self, trainer):
if trainer.amp_backend is not None:
raise MisconfigurationException(
'DDPSequentialPlugin is currently not supported in Automatic Mixed Precision')
def configure_ddp(
self,
model: LightningModule, device_ids: List[int]) -> DistributedDataParallel:
ddp_plugin = RPCPlugin(process_group=mpu.get_data_parallel_group()).configure_ddp(model, device_ids)
# Plugin handle backwards across processes. Currently not supported for DDP + pipe parallel
ddp_plugin.PREPARE_FOR_BACKWARDS = False
return ddp_plugin
@rank_zero_only
def rpc_save_model(
self,
save_model_fn,
last_filepath,
trainer,
pl_module) -> None:
model = trainer.get_model()
if not hasattr(model.sequential_module, "foreach_worker"):
return
current_layers = pl_module.sequential_module
model.sequential_module.foreach_worker(
save_layers_on_all_rank_zero_workers,
{"gpus_per_model": self.gpus_per_model},
include_self=True
)
pl_module.sequential_module = load_sequential_from_saved_layers(self.gpus_per_model)
save_model_fn(last_filepath, trainer, pl_module)
pl_module.sequential_module = current_layers
def worker_optimizer_step(
self,
model: LightningModule,
opt_idx: int,
*args,
**kwargs) -> None:
model.sequential_module.foreach_worker(
run_optimizer,
{"opt_idx": opt_idx, "args": args, "kwargs": kwargs},
include_self=False
)
def distributed_sampler_kwargs(self, distributed_sampler_kwargs):
return dict(
num_replicas=mpu.get_data_parallel_world_size(),
rank=mpu.get_data_parallel_rank(),
)
@property
def data_parallel_group(self):
return mpu.get_data_parallel_group()
@property
def is_main_rpc_process(self) -> bool:
return self.main_rpc_process
@property
def return_after_exit_rpc_process(self) -> bool:
return True
def barrier(self, name: Optional[str] = None) -> None:
if torch_distrib.is_initialized() and self.is_main_rpc_process:
torch_distrib.barrier(group=self.data_parallel_group)
def _check_pipe_available(self):
if not FAIRSCALE_PIPE_AVAILABLE:
raise MisconfigurationException(
'PipeRPCPlugin requires FairScale and currently is only supported on PyTorch 1.6.'
)
class LightningPipeModule(nn.Module):
"""
This class wraps Fairscale Pipe and PipeRCPWrapper class.
"""
def __init__(
self,
module: nn.Sequential,
balance: List[int],
microbatches: int = 8,
checkpoint='never'):
super().__init__()
self.module = module
self.balance = balance
self.microbatches = microbatches
self.checkpoint = checkpoint
self._init_pipe()
def _init_pipe(self):
device = torch.device("cuda", torch_distrib.get_rank())
self.module = PipeRPCWrapper(
module=self.module,
balance=self.balance,
chunks=self.microbatches,
style=PipelineStyle.MultiProcess,
input_device=device,
worker_map=self.get_worker_map(),
checkpoint=self.checkpoint,
)
def foreach_worker(self, *args, **kwargs):
self.module.foreach_worker(*args, **kwargs)
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
def get_worker_map(self):
# TODO, is this correct with multinodes? We also assume "worker" is the same as defined in the RPCPlugin
return {rank: f"worker{rank}" for rank in range(torch_distrib.get_world_size())}
def register_optimizers(ctx, model):
optimizers, lr_schedulers, optimizer_frequencies = model.trainer.init_optimizers(model)
model.trainer.optimizers = optimizers
model.trainer.lr_schedulers = lr_schedulers
model.trainer.optimizer_frequencies = optimizer_frequencies
model.trainer.convert_to_lightning_optimizers()
def run_optimizer(ctx, model):
trainer = model.trainer
opt_idx = ctx["opt_idx"]
optimizer = trainer.optimizers[opt_idx]
optimizer.step(*ctx["args"], **ctx["kwargs"])
def save_layers_on_all_rank_zero_workers(ctx, model):
gpus_per_model = ctx["gpus_per_model"]
rank = torch_distrib.get_rank()
if rank in range(gpus_per_model):
seq = list(model.children())[0]
torch.save(seq, f"seq_{rank}.pt")
def load_sequential_from_saved_layers(gpus_per_model):
partial_seqs = [torch.load(f"seq_{rank}.pt", map_location='cpu') for rank in range(gpus_per_model)]
seq = nn.Sequential()
for p_seq in partial_seqs:
for name, child in p_seq.named_children():
seq.add_module(name, child)
# delete tmp files
[os.remove(f"seq_{rank}.pt") for rank in range(gpus_per_model)]
return seq

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from enum import Enum from enum import Enum
from typing import Union, Optional, List from typing import List, Optional, Union
from pytorch_lightning.cluster_environments import ClusterEnvironment from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.plugins.apex import ApexPlugin from pytorch_lightning.plugins.apex import ApexPlugin
@ -163,16 +163,16 @@ class PluginConnector:
@classmethod @classmethod
def available_plugins(cls): def available_plugins(cls):
""" """
List of all available plugins that can be string arguments to the trainer. List of all available plugins that can be string arguments to the trainer.
Returns: List of all available plugins that are supported as string arguments. Returns: List of all available plugins that are supported as string arguments.
""" """
return [e.name for e in LightningCustomPlugins] return [e.name for e in LightningCustomPlugins]
class LightningCustomPlugins(Enum): class LightningCustomPlugins(Enum):
""" """
String support for custom lightning plugins. String support for custom lightning plugins.
Allows easier access to custom lightning plugins from the command line. Allows easier access to custom lightning plugins from the command line.
""" """
ddp_sharded = DDPShardedPlugin ddp_sharded = DDPShardedPlugin
native_amp = NativeAMPPlugin native_amp = NativeAMPPlugin

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os
from typing import Optional from typing import Any, Optional
import torch import torch

View File

@ -14,14 +14,14 @@
"""General utilities""" """General utilities"""
import importlib import importlib
import platform import platform
from distutils.version import LooseVersion
from enum import Enum from enum import Enum
import numpy import numpy
import torch import torch
from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.distributed import AllGatherGrad, rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.distributed import AllGatherGrad
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable
from pytorch_lightning.utilities.xla_device_utils import XLA_AVAILABLE, XLADeviceUtils from pytorch_lightning.utilities.xla_device_utils import XLA_AVAILABLE, XLADeviceUtils
@ -34,14 +34,18 @@ def _module_available(module_path: str) -> bool:
>>> _module_available('bla.bla') >>> _module_available('bla.bla')
False False
""" """
mods = module_path.split('.') # todo: find a better way than try / except
assert mods, 'nothing given to test' try:
# it has to be tested as per partets mods = module_path.split('.')
for i in range(len(mods)): assert mods, 'nothing given to test'
module_path = '.'.join(mods[:i + 1]) # it has to be tested as per partets
if importlib.util.find_spec(module_path) is None: for i in range(len(mods)):
return False module_path = '.'.join(mods[:i + 1])
return True if importlib.util.find_spec(module_path) is None:
return False
return True
except AttributeError:
return False
APEX_AVAILABLE = _module_available("apex.amp") APEX_AVAILABLE = _module_available("apex.amp")
@ -54,6 +58,8 @@ TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel') FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel')
RPC_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.rpc') RPC_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.rpc')
GROUP_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.group') GROUP_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.group')
FAIRSCALE_PIPE_AVAILABLE = FAIRSCALE_AVAILABLE and LooseVersion(torch.__version__) == LooseVersion("1.6.0")
BOLTS_AVAILABLE = _module_available('pl_bolts')
FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps

View File

@ -7,4 +7,4 @@ torchtext>=0.3.1, <0.7 # TODO: temporary fix fix for compatibility
onnx>=1.7.0 onnx>=1.7.0
onnxruntime>=1.3.0 onnxruntime>=1.3.0
hydra-core>=1.0 hydra-core>=1.0
https://github.com/facebookresearch/fairscale/archive/8e85ce8c93569017521d92ceb78dba2c57c955a0.zip # TODO temporary fix till release version https://github.com/PyTorchLightning/fairscale/archive/pl_1.1.0.zip

View File

@ -12,15 +12,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License # limitations under the License
import pytest
import os import os
from tests.base.boring_model import BoringModel
from pytorch_lightning.callbacks import Callback
from pytorch_lightning import accelerators, Trainer
from pytorch_lightning.cluster_environments import SLURMEnvironment, TorchElasticEnvironment, ClusterEnvironment
from pytorch_lightning.accelerators import Accelerator
from unittest import mock from unittest import mock
import pytest
from pytorch_lightning import Trainer, accelerators
from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.cluster_environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment
from tests.base.boring_model import BoringModel
def test_accelerator_choice_cpu(tmpdir): def test_accelerator_choice_cpu(tmpdir):
class CB(Callback): class CB(Callback):

View File

@ -0,0 +1,212 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock
import pytest
import torch
import torch.distributed as torch_distrib
from torch import nn
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins.ddp_sequential_plugin import DDPSequentialPlugin
from pytorch_lightning.utilities import FAIRSCALE_PIPE_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import RandomDataset
def cleanup(ctx, model):
"""
Cleanup function required to ensure we delete the pipe module at the end of the the test on all workers
"""
del model
@pytest.mark.skipif(not FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
reason="test should be run outside of pytest")
def test_ddp_sequential_plugin_ddp_rpc_manual(tmpdir, args=None):
model = SequentialModelRPCManual()
trainer = Trainer(
max_epochs=2,
limit_train_batches=2,
limit_val_batches=2,
limit_test_batches=2,
gpus=2,
distributed_backend="ddp",
plugins=[DDPSequentialPlugin(balance=[2, 1])],
)
trainer.fit(model)
if torch_distrib.get_rank() == 0:
assert len(trainer.dev_debugger.pbar_added_metrics) > 0
if trainer.accelerator_backend.rpc_enabled:
# Called at the end of trainer to ensure all processes are killed
trainer.accelerator_backend.ddp_plugin.exit_rpc_process()
@pytest.mark.skipif(not FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
reason="test should be run outside of pytest")
def test_ddp_sequential_plugin_ddp_rpc_manual_amp(tmpdir, args=None):
model = SequentialModelRPCManual()
trainer = Trainer(
max_epochs=2,
limit_train_batches=2,
limit_val_batches=2,
limit_test_batches=2,
gpus=2,
precision=16,
amp_backend="native",
distributed_backend="ddp",
plugins=[DDPSequentialPlugin(balance=[2, 1])],
)
try:
trainer.fit(model)
assert len(trainer.dev_debugger.pbar_added_metrics) > 0
except MisconfigurationException as e:
assert str(e) == 'DDPSequentialPlugin is currently not supported in Automatic Mixed Precision'
@pytest.mark.skipif(not FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
reason="test should be run outside of pytest")
def test_ddp_sequential_plugin_ddp_rpc_automatic(tmpdir, args=None):
model = SequentialModelRPCAutomatic()
trainer = Trainer(
max_epochs=2,
limit_train_batches=2,
limit_val_batches=2,
limit_test_batches=2,
gpus=2,
distributed_backend="ddp",
plugins=[DDPSequentialPlugin(balance=[2, 1])],
)
trainer.fit(model)
if torch_distrib.get_rank() == 0:
assert len(trainer.dev_debugger.pbar_added_metrics) > 0
if trainer.accelerator_backend.rpc_enabled:
# Called at the end of trainer to ensure all processes are killed
trainer.accelerator_backend.ddp_plugin.exit_rpc_process()
@pytest.mark.skipif(not FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
reason="test should be run outside of pytest")
def test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance(tmpdir, args=None):
model = SequentialModelRPCAutomatic()
trainer = Trainer(
max_epochs=2,
limit_train_batches=2,
limit_val_batches=2,
limit_test_batches=2,
gpus=2,
distributed_backend="ddp",
plugins=[DDPSequentialPlugin(balance=[2, 2])],
)
try:
trainer.fit(model)
except MisconfigurationException as e:
assert str(e) == 'The provided balance sum: 4 does not match your Sequential length: 3'
if trainer.accelerator_backend.rpc_enabled:
# Called at the end of trainer to ensure all processes are killed
trainer.accelerator_backend.ddp_plugin.exit_rpc_process()
class SequentialModelRPCManual(LightningModule):
def __init__(self):
super().__init__()
self.sequential_module = nn.Sequential(torch.nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))
def forward(self, x):
return self.sequential_module(x)
def loss(self, prediction):
# An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))
def step(self, x):
x = self(x)
out = torch.nn.functional.mse_loss(x, torch.ones_like(x))
return out
def training_step(self, batch, batch_idx):
opt = self.optimizers()
output = self.sequential_module(batch)
loss = self.loss(output)
self.log("train_loss", loss, on_epoch=True, prog_bar=True)
self.manual_backward(loss, opt)
assert torch.stack([torch.abs(p.grad).sum() for p in self.parameters()]).sum() > 0
opt.step()
assert torch.stack([torch.abs(p.grad).sum() for p in self.parameters()]).sum() == 0
def validation_step(self, batch, batch_idx):
output = self.sequential_module(batch)
loss = self.loss(output)
return loss
def test_step(self, batch, batch_idx):
output = self.sequential_module(batch)
return self.loss(batch, output)
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
return [optimizer], [lr_scheduler]
def train_dataloader(self):
return torch.utils.data.DataLoader(RandomDataset(32, 64))
def val_dataloader(self):
return torch.utils.data.DataLoader(RandomDataset(32, 64))
def test_dataloader(self):
return torch.utils.data.DataLoader(RandomDataset(32, 64))
@property
def automatic_optimization(self) -> bool:
return False
class SequentialModelRPCAutomatic(SequentialModelRPCManual):
def training_step(self, batch, batch_idx):
output = self.sequential_module(batch)
loss = self.loss(output)
self.log("train_loss", loss, on_epoch=True, prog_bar=True)
return loss
@property
def automatic_optimization(self) -> bool:
return True

View File

@ -5,7 +5,7 @@ from unittest import mock
import pytest import pytest
import torch import torch
from pytorch_lightning import Trainer, LightningModule from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import RPC_AVAILABLE from pytorch_lightning.utilities import RPC_AVAILABLE

View File

@ -15,4 +15,8 @@
export PL_RUNNING_SPECIAL_TESTS=1 export PL_RUNNING_SPECIAL_TESTS=1
DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no" DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no"
python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp
python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp
python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual
python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual_amp
python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_automatic
# python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance