[feat] pp 2/n (#5026)
* Added changes for RPC plugin
* Add missing kwargs
* Fix code format
* Loading refactors by introducing is_distributed var, fix optimizer step flow
* Add rpc guard
* Added docstrings and typing
* resolve comments
* Add additional rpc hook, refactor name of exit process hook for clarity
* remove annotation
* Modify behaviour to allow optional return, add test for rpc plugin
* resolve tests
* rename is_ddp_based
* update
* update for windows
* update
* resolve test
* code smell
* Added sequential plugin
* resolve bug
* update
* cleanup
* add Exception
* resolve docs
* Remove ddp support
* Revert distributed -> ddp
* Update pl_examples/basic_examples/conv_sequential_example.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
* Update pl_examples/basic_examples/conv_sequential_example.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
* Update pytorch_lightning/plugins/ddp_sequential_plugin.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
* Address code review points
* Update pytorch_lightning/plugins/ddp_sequential_plugin.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
* Update pytorch_lightning/plugins/ddp_sequential_plugin.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
* Add missing return
* Fix formatting, add datamodule args
* add small comment
* resolve comments
* resolve comments
* update source for fairscale
* update extras
* remove staticmethod
* resolve flake8
* Skip tests that are failing due to bug upstream with multiple optimizers and shard
* update
* update on comments
* clean test
* latest comments
* remove old comments
* add todo
* Update version
* update
* resolve bugs
* resolve bugs
* update test
* remove hanging test
* Update pytorch_lightning/plugins/ddp_sequential_plugin.py Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
* resolve on comments
* Update pytorch_lightning/plugins/ddp_sequential_plugin.py Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
* resolve on comments
* Update pytorch_lightning/plugins/ddp_sequential_plugin.py Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
* Update pytorch_lightning/plugins/ddp_sequential_plugin.py Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
* Update pytorch_lightning/plugins/ddp_sequential_plugin.py Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
* Update pytorch_lightning/plugins/ddp_sequential_plugin.py Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
* remove ImportError

Co-authored-by: SeanNaren <sean@grid.ai>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
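For context, a minimal usage sketch of the sequential plugin this PR introduces (it mirrors the docstring example and the tests added below; the layer sizes, GPU count and balance are illustrative, and a full module would also define training_step/configure_optimizers as in those tests):

import torch.nn as nn

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins.ddp_sequential_plugin import DDPSequentialPlugin


class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        # the plugin pipelines whatever is assigned to `sequential_module`
        self.sequential_module = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))


# split the three layers across 2 GPUs: two on the first, one on the second
trainer = Trainer(gpus=2, accelerator='ddp', plugins=[DDPSequentialPlugin(balance=[2, 1])])
trainer.fit(MyModel())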
parent: 7d9784e951
commit: ef8ef12fd0
@@ -32,5 +32,6 @@ repos:
        types: [python]

  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: master
    hooks:
      - id: mypy
@@ -131,6 +131,7 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):
    )


@pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
                    reason="Distributed training is not supported on Windows")
@@ -148,6 +149,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
    )


@pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
                    reason="Distributed training is not supported on Windows")
@@ -189,7 +191,7 @@ class SeedTrainLoaderManualModel(SeedTrainLoaderModel):

        # ensure we forward the correct params to the optimizer
        # without retain_graph we can't do multiple backward passes
        self.manual_backward(loss_2, opt_b, retain_graph=True)
        self.manual_backward(loss_2, opt_b)
        # todo: understand why synchronization breaks there.
        # self.manual_backward(loss_2, opt_a, retain_graph=True)
        opt_b.step()
@@ -0,0 +1,216 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example script of running the experimental DDP Sequential Plugin.
This script splits a convolutional model onto multiple GPUs, whilst using the internal built in balancer
to balance across your GPUs.

To run:
python conv_model_sequential_example.py --accelerator ddp --gpus 4 --max_epochs 1 --batch_size 256 --use_ddp_sequential
"""
import math
from argparse import ArgumentParser

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.plugins.ddp_sequential_plugin import DDPSequentialPlugin
from pytorch_lightning.utilities import BOLTS_AVAILABLE, FAIRSCALE_PIPE_AVAILABLE

if BOLTS_AVAILABLE:
    import pl_bolts
    from pl_bolts.transforms.dataset_normalizations import cifar10_normalization


#####################
#      Modules      #
#####################


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

###############################
#       LightningModule       #
###############################


class LitResnet(pl.LightningModule):
    def __init__(self, lr=0.05, batch_size=32, manual_optimization=False):
        super().__init__()

        self.save_hyperparameters()
        self.sequential_module = nn.Sequential(
            # Conv Layer block 1
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=False),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(inplace=False),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Conv Layer block 2
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=False),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(inplace=False),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout2d(p=0.05),

            # Conv Layer block 3
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=False),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(inplace=False),
            nn.MaxPool2d(kernel_size=2, stride=2),

            Flatten(),

            nn.Dropout(p=0.1),
            nn.Linear(4096, 1024),
            nn.ReLU(inplace=False),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=False),
            nn.Dropout(p=0.1),
            nn.Linear(512, 10)
        )
        self._example_input_array = torch.randn((1, 3, 32, 32))
        self._manual_optimization = manual_optimization
        if self._manual_optimization:
            self.training_step = self.training_step_manual

    def forward(self, x):
        out = self.sequential_module(x)
        return F.log_softmax(out, dim=-1)

    def training_step_manual(self, batch, batch_idx):
        opt = self.optimizers()

        def closure():
            x, y = batch
            logits = self.forward(x)
            loss = F.nll_loss(logits, y)
            self.manual_backward(loss, opt)
            self.log('train_loss', loss, prog_bar=True)

        opt.step(closure=closure)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        self.log('Training Loss', loss)
        return loss

    def _evaluate(self, batch, batch_idx, stage=None):
        x, y = batch
        out = self.forward(x)
        logits = F.log_softmax(out, dim=-1)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=-1)
        acc = accuracy(preds, y)

        if stage:
            self.log(f'{stage}_loss', loss, prog_bar=True)
            self.log(f'{stage}_acc', acc, prog_bar=True)

        return loss, acc

    def validation_step(self, batch, batch_idx):
        return self._evaluate(batch, batch_idx, 'val')[0]

    def test_step(self, batch, batch_idx):
        loss, acc = self._evaluate(batch, batch_idx, 'test')
        self.log_dict({'test_loss': loss, 'test_acc': acc})

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': torch.optim.lr_scheduler.OneCycleLR(
                    optimizer,
                    0.1,
                    epochs=self.trainer.max_epochs,
                    steps_per_epoch=math.ceil(45000 / self.hparams.batch_size)),
                'interval': 'step',
            }
        }

    @property
    def automatic_optimization(self) -> bool:
        return not self._manual_optimization


#################################
#     Instantiate Data Module   #
#################################

def instantiate_datamodule(args):
    train_transforms = torchvision.transforms.Compose([
        torchvision.transforms.RandomCrop(32, padding=4),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        cifar10_normalization(),
    ])

    test_transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        cifar10_normalization(),
    ])

    cifar10_dm = pl_bolts.datamodules.CIFAR10DataModule(
        batch_size=args.batch_size,
        train_transforms=train_transforms,
        test_transforms=test_transforms,
        val_transforms=test_transforms,
    )

    return cifar10_dm


if __name__ == "__main__":
    parser = ArgumentParser(description="Pipe Example")
    parser.add_argument("--use_ddp_sequential", action="store_true")
    parser = Trainer.add_argparse_args(parser)
    parser = pl_bolts.datamodules.CIFAR10DataModule.add_argparse_args(parser)
    args = parser.parse_args()

    assert BOLTS_AVAILABLE, "Bolts is required for this example, install it via pip install pytorch-lightning-bolts"
    assert FAIRSCALE_PIPE_AVAILABLE, "FairScale and PyTorch 1.6 is required for this example."

    cifar10_dm = instantiate_datamodule(args)

    plugins = None
    if args.use_ddp_sequential:
        plugins = DDPSequentialPlugin()

    model = LitResnet(batch_size=args.batch_size, manual_optimization=not args.automatic_optimization)

    trainer = pl.Trainer.from_argparse_args(args, plugins=[plugins] if plugins else None)
    trainer.fit(model, cifar10_dm)
    trainer.test(model, datamodule=cifar10_dm)

    if trainer.accelerator_backend.rpc_enabled:
        # Called at the end of trainer to ensure all processes are killed
        trainer.accelerator_backend.ddp_plugin.exit_rpc_process()
@@ -155,6 +155,7 @@ class LightningDistributedDataParallel(DistributedDataParallel):
    """
    Override the forward call in lightning so it goes to training and validation step respectively
    """
    PREPARE_FOR_BACKWARDS = True

    def parallel_apply(self, replicas, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
@@ -165,6 +166,7 @@ class LightningDistributedDataParallel(DistributedDataParallel):
        fx_called: str = ''

        if self.device_ids:

            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            if len(self.device_ids) == 1:
                # --------------
@@ -195,7 +197,7 @@ class LightningDistributedDataParallel(DistributedDataParallel):
            else:
                output = self.module.validation_step(*inputs, **kwargs)

        if not self._reducer_prepared_for_backwards:
        if not self._reducer_prepared_for_backwards and self.PREPARE_FOR_BACKWARDS:
            self.reducer_prepare_for_backwards(output)

        if output is None:
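The new PREPARE_FOR_BACKWARDS class flag is what lets a plugin opt out of the reducer preparation above; a minimal sketch of how it is meant to be toggled (assumes a torch.distributed process group is already initialised, and mirrors configure_ddp() in the sequential plugin below; the wrapped module is a placeholder):

import torch.nn as nn

from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel

module = nn.Linear(2, 2).cuda()  # placeholder module for illustration
ddp_model = LightningDistributedDataParallel(module, device_ids=[0])
ddp_model.PREPARE_FOR_BACKWARDS = False  # backward sync is handled by the pipe plugin instead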
@@ -0,0 +1,409 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import os
from typing import Any, List, Optional

import torch
import torch.distributed as torch_distrib
from torch import nn
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import LightningModule
from pytorch_lightning import _logger as log
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import FAIRSCALE_PIPE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if FAIRSCALE_PIPE_AVAILABLE:
    import fairscale.nn.model_parallel as mpu
    from fairscale.nn import PipeRPCWrapper
    from fairscale.nn.pipe import balance as pipe_balance
    from fairscale.nn.pipe import rpc as rpc_pipe
    from fairscale.nn.pipe.pipeline import PipelineStyle


class DDPSequentialPlugin(RPCPlugin):
    def __init__(
            self,
            balance: Optional[List[int]] = None,
            microbatches: int = 8,
            checkpoint: str = 'except_last',
            balance_mode: str = "balance_by_size",
            pipelined_backward: Optional[bool] = True,
            **kwargs):
        """
        Provides sequential model parallelism for a :class:`nn.Sequential <torch.nn.Sequential>` module.
        If the module requires lots of memory, Pipe can be used to reduce this by leveraging multiple GPUs.

        Example::

            class MyLightningModule:
                def __init__(self):
                    ...
                    model.sequential_module = torch.nn.Sequential(my_layers)

            # Split my module across 4 gpus, one layer each
            model = MyLightningModule()
            plugin = DDPSequentialPlugin(balance=[1, 1, 1, 1])
            trainer = Trainer(accelerator='ddp', gpus=4, plugins=[plugin])
            trainer.fit(model)

        .. _DDPSequentialPlugin: https://arxiv.org/abs/1811.06965

        Pipeline parallelism comes with checkpointing to reduce the peak
        memory required to train while minimizing device under-utilization.
        This is turned on by default and can be turned off via the checkpoint argument.

        You should determine the balance when defining the plugin,
        or you can pass an example input array via the LightningModule to infer a balance.
        The module will be partitioned into multiple devices according to the given balance. You may also rely on
        your own heuristics to find your own optimal configuration.

        Args:
            balance: The balance of the model, e.g. [2, 2] (two layers on each GPU).
                If not provided, assumes the user provides an example input array to find a balance on all GPUs.

            microbatches: Allows for parallelization to reduce device utilization
                by splitting the batch into further smaller batches.

            checkpoint: Enables gradient checkpointing. ['always', 'except_last', 'never']

            balance_mode: Type of balance heuristic to use if the balance is to be inferred.

                - 'balance_by_size': checks memory usage of each layer and determines balance

                - 'balance_by_time': checks time of each layer and determines balance

            pipelined_backward: if True, call torch.autograd.backward once per microbatch on the
                backward pass (instead of once for the whole batch). This works
                around a potential deadlock in pytorch when using tensor parallelism
                at the same time. Defaults to `True` if
                `get_model_parallel_world_size() > 1`
        """
        self._check_pipe_available()
        super().__init__(**kwargs)

        self.balance = balance

        self.microbatches = microbatches
        self.checkpoint = checkpoint
        self.balance_mode = balance_mode
        self.pipelined_backward = pipelined_backward
        self.main_rpc_process = False  # Updated by main process, default for all secondary processes

    def init_ddp_connection(
            self,
            trainer,
            cluster_environment,
            global_rank: int,
            world_size: int,
            is_slurm_managing_tasks: bool = True,
    ) -> None:
        trainer.prepared_for_backwards = False
        self._check_arguments(trainer)
        if self._skip_init_connections(trainer):
            return
        super().init_ddp_connection(
            trainer=trainer,
            cluster_environment=cluster_environment,
            global_rank=global_rank,
            world_size=world_size,
            is_slurm_managing_tasks=is_slurm_managing_tasks
        )
        super().init_rpc_connection(
            global_rank=global_rank,
            world_size=world_size
        )
        model = trainer.get_model()
        self.gpus_per_model = self._infer_check_num_gpus(trainer)
        self.init_model_parallel_groups(trainer)
        self.set_main_rpc_process()

        self._check_sequential_model_exists(model)
        if self.main_rpc_process:
            if self.balance is None:
                self._infer_model_balance(trainer)
            self._assert_valid_model_balance(trainer)

    def on_before_manual_backward(self, model: LightningDistributedDataParallel, output: Any):
        pass

    def _infer_model_balance(self, trainer):
        log.info(f'Inferring model balance using {self.balance_mode} mode')
        model = trainer.get_model()
        if model.example_input_array is None:
            raise MisconfigurationException(
                'Please set example_input_array to your model, so we can infer the right model balance for you')
        balance_func = getattr(pipe_balance, self.balance_mode)
        self.balance = balance_func(self.gpus_per_model, model.sequential_module, model.example_input_array)
        self._sync_balance_to_all_parallel_groups()

        log.info(f'The following model balance {self.balance.tolist()} was inferred using {self.balance_mode} mode')

    def _sync_balance_to_all_parallel_groups(self, main_rank=0):
        """
        Ensures that we sync the balance to all main processes, so that the balance is the same per replica.
        Args:
            main_rank: The rank with the balance we'd like to replicate.
        """
        self.balance = torch.tensor(self.balance, dtype=torch.int, device='cuda')
        # Ensure we sync to all processes within the main data parallel group
        # We use the data parallel group as all main processes are found within the same group
        torch_distrib.broadcast(self.balance, src=main_rank, group=mpu.get_data_parallel_group())
        self.balance = self.balance.cpu()

    def _check_sequential_model_exists(self, model):
        if not hasattr(model, "sequential_module") or not isinstance(model.sequential_module, nn.Sequential):
            raise MisconfigurationException(
                'Could not find a PipeLightningModule within the model. '
                'Did you set your sequential model as the `sequential_module` attribute of your model?')

    def _find_and_init_pipe_module(self, model):
        if hasattr(model, "sequential_module") and isinstance(model.sequential_module, LightningPipeModule):
            # model has been wrapped already
            return
        elif hasattr(model, "sequential_module") and isinstance(model.sequential_module, nn.Sequential):
            # try to wrap model for the user
            model.sequential_module = LightningPipeModule(
                model.sequential_module,
                balance=self.balance,
                microbatches=self.microbatches,
                checkpoint=self.checkpoint,
            )
            # Update references for workers to access correct lightning functions when calling RPC
            model.sequential_module.trainer = model.trainer
            model.sequential_module.configure_optimizers = model.configure_optimizers

            # Update references for main process to access correct lightning functions when calling RPC
            model.sequential_module.module.model.trainer = model.trainer
            model.sequential_module.module.model.configure_optimizers = model.configure_optimizers

        else:
            raise MisconfigurationException(
                'Could not find a PipeLightningModule within the model. '
                'Did you set your sequential model as the `sequential_module` attribute of your model?'
            )

    def _assert_valid_model_balance(self, trainer):
        model = trainer.get_model()
        if sum(self.balance) != len(model.sequential_module):
            raise MisconfigurationException(
                f'The provided balance sum: {sum(self.balance)} does not'
                f' match your Sequential length: {len(model.sequential_module)}')

    def _skip_init_connections(self, trainer):
        """
        Skip initialization if torch is already initialized and we're in testing.
        Returns: Whether to skip initialization

        """
        return torch_distrib.is_initialized() and trainer.testing

    def init_model_parallel_groups(self, trainer):
        num_model_parallel = 1  # TODO currently no support for vertical model parallel
        mpu.initialize_model_parallel(
            model_parallel_size_=num_model_parallel,
            pipeline_length=self.gpus_per_model
        )

    def _infer_check_num_gpus(self, trainer):
        """
        Infer the number of GPUs per model.

        Args:
            trainer: The trainer object.

        Returns: The appropriate balance for the model
        """
        if isinstance(self.balance, list):
            if len(self.balance) != trainer.world_size:
                raise MisconfigurationException(
                    "Pipe currently only supports splitting the module onto all available GPUs"
                )
            # User has defined a balance for his model
            return len(self.balance)
        # Assume that the user wants to balance his model on all GPUs
        return trainer.world_size

    def on_accelerator_exit_rpc_process(self, trainer) -> None:
        if not trainer.testing:
            torch_distrib.barrier()  # Ensure we await main process initialization

        # Add trainer/configure_optimizers to the pipe model for access in all worker processes
        rpc_pipe.PipeModel.trainer = trainer
        del rpc_pipe.PipeModel.trainer.model.sequential_module
        rpc_pipe.PipeModel.trainer.model.sequential_module = rpc_pipe.PipeModel
        rpc_pipe.PipeModel.configure_optimizers = trainer.model.configure_optimizers
        super().on_accelerator_exit_rpc_process(trainer)

    def set_main_rpc_process(self):
        self.main_rpc_process = torch_distrib.get_rank(group=mpu.get_pipeline_parallel_group()) == 0

    def on_main_rpc_connection(self, trainer) -> None:
        # Create pipe_module
        model = trainer.get_model()
        self._find_and_init_pipe_module(model)
        if not trainer.testing:
            torch_distrib.barrier()  # Ensure we join main process initialization
        model.sequential_module.foreach_worker(register_optimizers, include_self=True)

    def _check_arguments(self, trainer):
        if trainer.amp_backend is not None:
            raise MisconfigurationException(
                'DDPSequentialPlugin is currently not supported in Automatic Mixed Precision')

    def configure_ddp(
            self,
            model: LightningModule, device_ids: List[int]) -> DistributedDataParallel:
        ddp_plugin = RPCPlugin(process_group=mpu.get_data_parallel_group()).configure_ddp(model, device_ids)
        # Plugin handles backwards across processes. Currently not supported for DDP + pipe parallel
        ddp_plugin.PREPARE_FOR_BACKWARDS = False
        return ddp_plugin

    @rank_zero_only
    def rpc_save_model(
            self,
            save_model_fn,
            last_filepath,
            trainer,
            pl_module) -> None:
        model = trainer.get_model()
        if not hasattr(model.sequential_module, "foreach_worker"):
            return
        current_layers = pl_module.sequential_module
        model.sequential_module.foreach_worker(
            save_layers_on_all_rank_zero_workers,
            {"gpus_per_model": self.gpus_per_model},
            include_self=True
        )
        pl_module.sequential_module = load_sequential_from_saved_layers(self.gpus_per_model)
        save_model_fn(last_filepath, trainer, pl_module)
        pl_module.sequential_module = current_layers

    def worker_optimizer_step(
            self,
            model: LightningModule,
            opt_idx: int,
            *args,
            **kwargs) -> None:
        model.sequential_module.foreach_worker(
            run_optimizer,
            {"opt_idx": opt_idx, "args": args, "kwargs": kwargs},
            include_self=False
        )

    def distributed_sampler_kwargs(self, distributed_sampler_kwargs):
        return dict(
            num_replicas=mpu.get_data_parallel_world_size(),
            rank=mpu.get_data_parallel_rank(),
        )

    @property
    def data_parallel_group(self):
        return mpu.get_data_parallel_group()

    @property
    def is_main_rpc_process(self) -> bool:
        return self.main_rpc_process

    @property
    def return_after_exit_rpc_process(self) -> bool:
        return True

    def barrier(self, name: Optional[str] = None) -> None:
        if torch_distrib.is_initialized() and self.is_main_rpc_process:
            torch_distrib.barrier(group=self.data_parallel_group)

    def _check_pipe_available(self):
        if not FAIRSCALE_PIPE_AVAILABLE:
            raise MisconfigurationException(
                'PipeRPCPlugin requires FairScale and currently is only supported on PyTorch 1.6.'
            )


class LightningPipeModule(nn.Module):
    """
    This class wraps the FairScale Pipe and PipeRPCWrapper classes.
    """

    def __init__(
            self,
            module: nn.Sequential,
            balance: List[int],
            microbatches: int = 8,
            checkpoint='never'):
        super().__init__()
        self.module = module
        self.balance = balance
        self.microbatches = microbatches
        self.checkpoint = checkpoint
        self._init_pipe()

    def _init_pipe(self):
        device = torch.device("cuda", torch_distrib.get_rank())

        self.module = PipeRPCWrapper(
            module=self.module,
            balance=self.balance,
            chunks=self.microbatches,
            style=PipelineStyle.MultiProcess,
            input_device=device,
            worker_map=self.get_worker_map(),
            checkpoint=self.checkpoint,
        )

    def foreach_worker(self, *args, **kwargs):
        self.module.foreach_worker(*args, **kwargs)

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def get_worker_map(self):
        # TODO, is this correct with multiple nodes? We also assume "worker" is the same as defined in the RPCPlugin
        return {rank: f"worker{rank}" for rank in range(torch_distrib.get_world_size())}


def register_optimizers(ctx, model):
    optimizers, lr_schedulers, optimizer_frequencies = model.trainer.init_optimizers(model)
    model.trainer.optimizers = optimizers
    model.trainer.lr_schedulers = lr_schedulers
    model.trainer.optimizer_frequencies = optimizer_frequencies
    model.trainer.convert_to_lightning_optimizers()


def run_optimizer(ctx, model):
    trainer = model.trainer
    opt_idx = ctx["opt_idx"]
    optimizer = trainer.optimizers[opt_idx]
    optimizer.step(*ctx["args"], **ctx["kwargs"])


def save_layers_on_all_rank_zero_workers(ctx, model):
    gpus_per_model = ctx["gpus_per_model"]
    rank = torch_distrib.get_rank()
    if rank in range(gpus_per_model):
        seq = list(model.children())[0]
        torch.save(seq, f"seq_{rank}.pt")


def load_sequential_from_saved_layers(gpus_per_model):
    partial_seqs = [torch.load(f"seq_{rank}.pt", map_location='cpu') for rank in range(gpus_per_model)]
    seq = nn.Sequential()
    for p_seq in partial_seqs:
        for name, child in p_seq.named_children():
            seq.add_module(name, child)
    # delete tmp files
    [os.remove(f"seq_{rank}.pt") for rank in range(gpus_per_model)]
    return seq
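Besides an explicit balance, the plugin above can infer one from `example_input_array` (see _infer_model_balance); a minimal sketch, assuming a 2-GPU DDP run and illustrative layer sizes:

import torch
import torch.nn as nn

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins.ddp_sequential_plugin import DDPSequentialPlugin


class InferredBalanceModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.sequential_module = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))
        # required when no explicit balance is given: the plugin runs this input
        # through the layers to measure them and split the pipeline across GPUs
        self.example_input_array = torch.randn(1, 32)


# no `balance` argument: it is inferred with `balance_by_size` (the default balance_mode)
trainer = Trainer(gpus=2, accelerator='ddp', plugins=[DDPSequentialPlugin()])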
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from typing import Union, Optional, List
from typing import List, Optional, Union

from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.plugins.apex import ApexPlugin
@@ -163,16 +163,16 @@ class PluginConnector:
    @classmethod
    def available_plugins(cls):
        """
        List of all available plugins that can be string arguments to the trainer.
        Returns: List of all available plugins that are supported as string arguments.
        List of all available plugins that can be string arguments to the trainer.
        Returns: List of all available plugins that are supported as string arguments.
        """
        return [e.name for e in LightningCustomPlugins]


class LightningCustomPlugins(Enum):
    """
    String support for custom lightning plugins.
    Allows easier access to custom lightning plugins from the command line.
    String support for custom lightning plugins.
    Allows easier access to custom lightning plugins from the command line.
    """
    ddp_sharded = DDPShardedPlugin
    native_amp = NativeAMPPlugin
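Since available_plugins() exposes the enum member names as Trainer string arguments, a short sketch of how such a plugin is selected by name (the 'ddp_sharded' string comes from the enum above):

from pytorch_lightning import Trainer

# plugins can be passed by name; accepted strings are the LightningCustomPlugins members
trainer = Trainer(gpus=2, accelerator='ddp', plugins=['ddp_sharded'])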
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Optional
from typing import Any, Optional

import torch

@@ -14,14 +14,14 @@
"""General utilities"""
import importlib
import platform
from distutils.version import LooseVersion
from enum import Enum

import numpy
import torch

from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.distributed import AllGatherGrad
from pytorch_lightning.utilities.distributed import AllGatherGrad, rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable
from pytorch_lightning.utilities.xla_device_utils import XLA_AVAILABLE, XLADeviceUtils
@@ -34,14 +34,18 @@ def _module_available(module_path: str) -> bool:
    >>> _module_available('bla.bla')
    False
    """
    mods = module_path.split('.')
    assert mods, 'nothing given to test'
    # it has to be tested as per parts
    for i in range(len(mods)):
        module_path = '.'.join(mods[:i + 1])
        if importlib.util.find_spec(module_path) is None:
            return False
    return True
    # todo: find a better way than try / except
    try:
        mods = module_path.split('.')
        assert mods, 'nothing given to test'
        # it has to be tested as per parts
        for i in range(len(mods)):
            module_path = '.'.join(mods[:i + 1])
            if importlib.util.find_spec(module_path) is None:
                return False
        return True
    except AttributeError:
        return False


APEX_AVAILABLE = _module_available("apex.amp")
@@ -54,6 +58,8 @@ TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel')
RPC_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.rpc')
GROUP_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.group')
FAIRSCALE_PIPE_AVAILABLE = FAIRSCALE_AVAILABLE and LooseVersion(torch.__version__) == LooseVersion("1.6.0")
BOLTS_AVAILABLE = _module_available('pl_bolts')

FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
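The new availability flags follow the same import-guard pattern used by the plugin files added in this PR; a short sketch:

from pytorch_lightning.utilities import FAIRSCALE_PIPE_AVAILABLE, RPC_AVAILABLE

# guard optional imports behind the flags, as ddp_sequential_plugin.py does above
if RPC_AVAILABLE:
    from torch.distributed import rpc  # noqa: F401

if FAIRSCALE_PIPE_AVAILABLE:
    from fairscale.nn import PipeRPCWrapper  # noqa: F401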
@@ -7,4 +7,4 @@ torchtext>=0.3.1, <0.7 # TODO: temporary fix fix for compatibility
onnx>=1.7.0
onnxruntime>=1.3.0
hydra-core>=1.0
https://github.com/facebookresearch/fairscale/archive/8e85ce8c93569017521d92ceb78dba2c57c955a0.zip # TODO temporary fix till release version
https://github.com/PyTorchLightning/fairscale/archive/pl_1.1.0.zip
@@ -12,15 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License

import pytest
import os
from tests.base.boring_model import BoringModel
from pytorch_lightning.callbacks import Callback
from pytorch_lightning import accelerators, Trainer
from pytorch_lightning.cluster_environments import SLURMEnvironment, TorchElasticEnvironment, ClusterEnvironment
from pytorch_lightning.accelerators import Accelerator
from unittest import mock

import pytest

from pytorch_lightning import Trainer, accelerators
from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.cluster_environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment
from tests.base.boring_model import BoringModel


def test_accelerator_choice_cpu(tmpdir):
    class CB(Callback):
@@ -0,0 +1,212 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock

import pytest
import torch
import torch.distributed as torch_distrib
from torch import nn

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins.ddp_sequential_plugin import DDPSequentialPlugin
from pytorch_lightning.utilities import FAIRSCALE_PIPE_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import RandomDataset


def cleanup(ctx, model):
    """
    Cleanup function required to ensure we delete the pipe module at the end of the test on all workers
    """
    del model


@pytest.mark.skipif(not FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
                    reason="test should be run outside of pytest")
def test_ddp_sequential_plugin_ddp_rpc_manual(tmpdir, args=None):
    model = SequentialModelRPCManual()
    trainer = Trainer(
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        gpus=2,
        distributed_backend="ddp",
        plugins=[DDPSequentialPlugin(balance=[2, 1])],
    )

    trainer.fit(model)

    if torch_distrib.get_rank() == 0:
        assert len(trainer.dev_debugger.pbar_added_metrics) > 0

    if trainer.accelerator_backend.rpc_enabled:
        # Called at the end of trainer to ensure all processes are killed
        trainer.accelerator_backend.ddp_plugin.exit_rpc_process()


@pytest.mark.skipif(not FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
                    reason="test should be run outside of pytest")
def test_ddp_sequential_plugin_ddp_rpc_manual_amp(tmpdir, args=None):
    model = SequentialModelRPCManual()
    trainer = Trainer(
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        gpus=2,
        precision=16,
        amp_backend="native",
        distributed_backend="ddp",
        plugins=[DDPSequentialPlugin(balance=[2, 1])],
    )
    try:
        trainer.fit(model)

        assert len(trainer.dev_debugger.pbar_added_metrics) > 0

    except MisconfigurationException as e:
        assert str(e) == 'DDPSequentialPlugin is currently not supported in Automatic Mixed Precision'


@pytest.mark.skipif(not FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
                    reason="test should be run outside of pytest")
def test_ddp_sequential_plugin_ddp_rpc_automatic(tmpdir, args=None):
    model = SequentialModelRPCAutomatic()
    trainer = Trainer(
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        gpus=2,
        distributed_backend="ddp",
        plugins=[DDPSequentialPlugin(balance=[2, 1])],
    )

    trainer.fit(model)

    if torch_distrib.get_rank() == 0:
        assert len(trainer.dev_debugger.pbar_added_metrics) > 0

    if trainer.accelerator_backend.rpc_enabled:

        # Called at the end of trainer to ensure all processes are killed
        trainer.accelerator_backend.ddp_plugin.exit_rpc_process()


@pytest.mark.skipif(not FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
                    reason="test should be run outside of pytest")
def test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance(tmpdir, args=None):
    model = SequentialModelRPCAutomatic()
    trainer = Trainer(
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        gpus=2,
        distributed_backend="ddp",
        plugins=[DDPSequentialPlugin(balance=[2, 2])],
    )

    try:
        trainer.fit(model)

    except MisconfigurationException as e:
        assert str(e) == 'The provided balance sum: 4 does not match your Sequential length: 3'

    if trainer.accelerator_backend.rpc_enabled:
        # Called at the end of trainer to ensure all processes are killed
        trainer.accelerator_backend.ddp_plugin.exit_rpc_process()


class SequentialModelRPCManual(LightningModule):

    def __init__(self):
        super().__init__()
        self.sequential_module = nn.Sequential(torch.nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))

    def forward(self, x):
        return self.sequential_module(x)

    def loss(self, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def step(self, x):
        x = self(x)
        out = torch.nn.functional.mse_loss(x, torch.ones_like(x))
        return out

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        output = self.sequential_module(batch)
        loss = self.loss(output)
        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        self.manual_backward(loss, opt)
        assert torch.stack([torch.abs(p.grad).sum() for p in self.parameters()]).sum() > 0
        opt.step()
        assert torch.stack([torch.abs(p.grad).sum() for p in self.parameters()]).sum() == 0

    def validation_step(self, batch, batch_idx):
        output = self.sequential_module(batch)
        loss = self.loss(output)
        return loss

    def test_step(self, batch, batch_idx):
        output = self.sequential_module(batch)
        return self.loss(batch, output)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]

    def train_dataloader(self):
        return torch.utils.data.DataLoader(RandomDataset(32, 64))

    def val_dataloader(self):
        return torch.utils.data.DataLoader(RandomDataset(32, 64))

    def test_dataloader(self):
        return torch.utils.data.DataLoader(RandomDataset(32, 64))

    @property
    def automatic_optimization(self) -> bool:
        return False


class SequentialModelRPCAutomatic(SequentialModelRPCManual):

    def training_step(self, batch, batch_idx):
        output = self.sequential_module(batch)
        loss = self.loss(output)
        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        return loss

    @property
    def automatic_optimization(self) -> bool:
        return True
@@ -5,7 +5,7 @@ from unittest import mock
import pytest
import torch

from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import RPC_AVAILABLE
@@ -15,4 +15,8 @@
export PL_RUNNING_SPECIAL_TESTS=1
DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no"
python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp
python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp
python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp
python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual
python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual_amp
python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_automatic
# python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance