lightning/tests/plugins/test_rpc_sequential_plugin.py

186 lines
6.2 KiB
Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock
import pytest
import torch
import torch.distributed as torch_distrib
from torch import nn
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import RandomDataset
from tests.helpers.runif import RunIf
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@RunIf(min_gpus=2, special=True, fairscale_pipe=True)
def test_rpc_sequential_plugin_manual(tmpdir):
model = SequentialModelRPCManual()
trainer = Trainer(
max_epochs=2,
limit_train_batches=2,
limit_val_batches=2,
limit_test_batches=2,
gpus=2,
distributed_backend="ddp",
plugins=[RPCSequentialPlugin(balance=[2, 1], rpc_timeout_sec=5 * 60)],
)
trainer.fit(model)
if torch_distrib.is_initialized() and torch_distrib.get_rank() == 0:
assert len(trainer.dev_debugger.pbar_added_metrics) > 0
if trainer.accelerator.rpc_enabled:
# Called at the end of trainer to ensure all processes are killed
trainer.accelerator.training_type_plugin.exit_rpc_process()
@RunIf(min_gpus=2, special=True, fairscale_pipe=True)
def test_rpc_sequential_plugin_manual_amp(tmpdir):
model = SequentialModelRPCManual()
trainer = Trainer(
max_epochs=2,
limit_train_batches=2,
limit_val_batches=2,
limit_test_batches=2,
gpus=2,
precision=16,
amp_backend="native",
distributed_backend="ddp",
plugins=[RPCSequentialPlugin(balance=[2, 1])],
)
with pytest.raises(
MisconfigurationException,
match='`RPCSequentialPlugin` is currently not supported in Automatic Mixed Precision'
):
trainer.fit(model)
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@RunIf(min_gpus=2, special=True, fairscale_pipe=True)
def test_rpc_sequential_plugin_automatic(tmpdir):
model = SequentialModelRPCAutomatic()
trainer = Trainer(
max_epochs=2,
limit_train_batches=2,
limit_val_batches=2,
limit_test_batches=2,
gpus=2,
distributed_backend="ddp",
plugins=[RPCSequentialPlugin(balance=[2, 1])],
)
trainer.fit(model)
if torch_distrib.is_initialized() and torch_distrib.get_rank() == 0:
assert len(trainer.dev_debugger.pbar_added_metrics) > 0
if trainer.accelerator.rpc_enabled:
# Called at the end of trainer to ensure all processes are killed
trainer.accelerator.training_type_plugin.exit_rpc_process()
@RunIf(min_gpus=2, special=True, fairscale_pipe=True)
def test_rpc_sequential_plugin_with_wrong_balance(tmpdir):
model = SequentialModelRPCAutomatic()
trainer = Trainer(
max_epochs=2,
limit_train_batches=2,
limit_val_batches=2,
limit_test_batches=2,
gpus=2,
distributed_backend="ddp",
plugins=[RPCSequentialPlugin(balance=[2, 2])],
)
with pytest.raises(
MisconfigurationException, match="The provided balance sum: 4 does not match your Sequential length: 3"
):
trainer.fit(model)
if trainer.accelerator.rpc_enabled:
# Called at the end of trainer to ensure all processes are killed
trainer.accelerator.training_type_plugin.exit_rpc_process()
class SequentialModelRPCManual(LightningModule):
def __init__(self):
super().__init__()
self.sequential_module = nn.Sequential(torch.nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))
self.automatic_optimization = False
def forward(self, x):
return self.sequential_module(x)
def loss(self, prediction):
# An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))
def step(self, x):
x = self(x)
out = torch.nn.functional.mse_loss(x, torch.ones_like(x))
return out
def training_step(self, batch, batch_idx):
opt = self.optimizers()
output = self.sequential_module(batch)
loss = self.loss(output)
self.log("train_loss", loss, on_epoch=True, prog_bar=True)
self.manual_backward(loss, opt)
assert torch.stack([torch.abs(p.grad).sum() for p in self.parameters()]).sum() > 0
opt.step()
opt.zero_grad()
assert torch.stack([torch.abs(p.grad).sum() for p in self.parameters()]).sum() == 0
def validation_step(self, batch, batch_idx):
output = self.sequential_module(batch)
loss = self.loss(output)
return loss
def test_step(self, batch, batch_idx):
output = self.sequential_module(batch)
return self.loss(batch, output)
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
return [optimizer], [lr_scheduler]
def train_dataloader(self):
return torch.utils.data.DataLoader(RandomDataset(32, 64))
def val_dataloader(self):
return torch.utils.data.DataLoader(RandomDataset(32, 64))
def test_dataloader(self):
return torch.utils.data.DataLoader(RandomDataset(32, 64))
class SequentialModelRPCAutomatic(SequentialModelRPCManual):
def __init__(self):
super().__init__()
self.automatic_optimization = True
def training_step(self, batch, batch_idx):
output = self.sequential_module(batch)
loss = self.loss(output)
self.log("train_loss", loss, on_epoch=True, prog_bar=True)
return loss