227 lines
7.4 KiB
Python
227 lines
7.4 KiB
Python
# Copyright The PyTorch Lightning team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
|
|
Example script of running the experimental DDP Sequential Plugin.
|
|
This script splits a convolutional model across multiple GPUs, using the internal built-in balancer
|
|
to balance across your GPUs.
|
|
|
|
To run:
|
|
python conv_model_sequential_example.py --accelerator ddp --gpus 4 --max_epochs 1 --batch_size 256 --use_rpc_sequential
|
|
"""
|
|
import math
|
|
from argparse import ArgumentParser
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torchvision
|
|
from torchmetrics.functional import accuracy
|
|
|
|
import pytorch_lightning as pl
|
|
from pl_examples import cli_lightning_logo
|
|
from pytorch_lightning import Trainer
|
|
from pytorch_lightning.plugins import RPCSequentialPlugin
|
|
from pytorch_lightning.utilities import _BOLTS_AVAILABLE, _FAIRSCALE_PIPE_AVAILABLE
|
|
|
|
if _BOLTS_AVAILABLE:
|
|
import pl_bolts
|
|
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
|
|
|
|
#####################
|
|
# Modules #
|
|
#####################
|
|
|
|
|
|
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single dimension.

    An input whose first dimension has size ``N`` is returned with shape
    ``(N, -1)``, i.e. the remaining dimensions are merged into one.
    """

    def forward(self, x):
        """Return ``x`` reshaped to ``(x.size(0), -1)``.

        ``reshape`` is used rather than ``view`` so that non-contiguous inputs
        (for example transposed tensors) are accepted as well; when a view is
        possible, ``reshape`` returns one, so contiguous inputs behave exactly
        as before.
        """
        return x.reshape(x.size(0), -1)
|
|
|
|
|
|
###############################
|
|
# LightningModule #
|
|
###############################
|
|
|
|
|
|
class LitResnet(pl.LightningModule):
    """
    Convolutional image classifier whose layers all live in one ``nn.Sequential``.

    The network (``sequential_module``) consists of three convolutional blocks
    followed by a fully connected head. It is built for 3-channel 32x32 inputs
    (three 2x2 max-pools leave 256 * 4 * 4 = 4096 features for the first linear
    layer) and ``forward`` returns log-probabilities over 10 classes.

    Args:
        lr: learning rate handed to the SGD optimizer.
        batch_size: batch size, used to derive the number of scheduler steps per epoch.
        manual_optimization: when ``True``, automatic optimization is switched off and
            ``training_step_manual`` is used as the training step.

    >>> LitResnet()  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    LitResnet(
      (sequential_module): Sequential(...)
    )
    """

    def __init__(self, lr=0.05, batch_size=32, manual_optimization=False):
        super().__init__()

        # Stores ``lr``, ``batch_size`` and ``manual_optimization`` under ``self.hparams``.
        self.save_hyperparameters()
        self.sequential_module = nn.Sequential(
            # Conv Layer block 1
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=False),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(inplace=False),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Conv Layer block 2
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=False),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(inplace=False),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout2d(p=0.05),

            # Conv Layer block 3
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=False),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(inplace=False),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Classifier head
            Flatten(),
            nn.Dropout(p=0.1),
            nn.Linear(4096, 1024),
            nn.ReLU(inplace=False),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=False),
            nn.Dropout(p=0.1),
            nn.Linear(512, 10),
        )
        self._example_input_array = torch.randn((1, 3, 32, 32))

        if manual_optimization:
            self.automatic_optimization = False
            # Rebind on the instance so the manual variant is what gets called as ``training_step``.
            self.training_step = self.training_step_manual

    def forward(self, x):
        """Run ``x`` through the sequential network and return log-probabilities."""
        out = self.sequential_module(x)
        return F.log_softmax(out, dim=-1)

    def training_step_manual(self, batch, batch_idx):
        """Training step used when ``manual_optimization=True``.

        Forward pass, loss, backward pass and logging all happen inside a
        closure that is handed to the optimizer's ``step``.
        """
        opt = self.optimizers()

        def closure():
            x, y = batch
            logits = self.forward(x)
            loss = F.nll_loss(logits, y)
            self.manual_backward(loss, opt)
            self.log('train_loss', loss, prog_bar=True)
            # Optimizer closures are expected to return the loss they computed.
            return loss

        # NOTE(review): gradients are never zeroed explicitly here; confirm the installed
        # Lightning version clears them as part of ``step`` in manual optimization.
        opt.step(closure=closure)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the NLL loss for one training batch."""
        x, y = batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        self.log('Training Loss', loss)
        return loss

    def _evaluate(self, batch, batch_idx, stage=None):
        """Compute loss and accuracy for one batch.

        Args:
            batch: ``(inputs, targets)`` pair.
            batch_idx: index of the batch (unused).
            stage: optional prefix; when given, ``{stage}_loss`` and ``{stage}_acc``
                are logged to the progress bar.

        Returns:
            Tuple ``(loss, acc)``.
        """
        x, y = batch
        # ``forward`` already applies ``log_softmax``; its output feeds ``nll_loss`` directly.
        log_probs = self.forward(x)
        loss = F.nll_loss(log_probs, y)
        preds = torch.argmax(log_probs, dim=-1)
        acc = accuracy(preds, y)

        if stage:
            self.log(f'{stage}_loss', loss, prog_bar=True)
            self.log(f'{stage}_acc', acc, prog_bar=True)

        return loss, acc

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` / ``val_acc`` and return the validation loss."""
        return self._evaluate(batch, batch_idx, 'val')[0]

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` and ``test_acc`` for one test batch."""
        # No ``stage`` is passed: the metrics are logged exactly once, by ``log_dict`` below,
        # instead of being logged a first time under the same keys inside ``_evaluate``.
        loss, acc = self._evaluate(batch, batch_idx)
        self.log_dict({'test_loss': loss, 'test_acc': acc})

    def configure_optimizers(self):
        """Return SGD with a per-step ``OneCycleLR`` schedule."""
        optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)
        # NOTE(review): 45000 is presumably the size of the training split served by the
        # datamodule; confirm it if the dataset, validation split or number of processes changes.
        steps_per_epoch = math.ceil(45000 / self.hparams.batch_size)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': torch.optim.lr_scheduler.OneCycleLR(
                    optimizer,
                    max_lr=0.1,
                    epochs=self.trainer.max_epochs,
                    steps_per_epoch=steps_per_epoch,
                ),
                # Step the scheduler after every optimizer step rather than once per epoch.
                'interval': 'step',
            }
        }
|
|
|
|
|
|
#################################
|
|
# Instantiate Data Module #
|
|
#################################
|
|
|
|
|
|
def instantiate_datamodule(args):
    """Create the CIFAR10 datamodule described by the parsed arguments.

    ``args.data_dir`` and ``args.batch_size`` are forwarded to the datamodule.
    Training samples are randomly cropped (32 pixels, padding 4) and randomly
    flipped horizontally before tensor conversion and normalization; the
    validation and test sets share one pipeline with tensor conversion and
    normalization only.
    """
    tfm = torchvision.transforms

    # Deterministic pipeline shared by validation and test.
    eval_pipeline = tfm.Compose([tfm.ToTensor(), cifar10_normalization()])

    # Training pipeline: augmentation first, then the same conversion/normalization steps.
    augmented_pipeline = tfm.Compose([
        tfm.RandomCrop(32, padding=4),
        tfm.RandomHorizontalFlip(),
        tfm.ToTensor(),
        cifar10_normalization(),
    ])

    return pl_bolts.datamodules.CIFAR10DataModule(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        train_transforms=augmented_pipeline,
        test_transforms=eval_pipeline,
        val_transforms=eval_pipeline,
    )
|
|
|
|
|
|
if __name__ == "__main__":
    cli_lightning_logo()

    # Explicit raises rather than ``assert``: assertions are removed under ``python -O``,
    # which would let the script continue and fail later on an undefined name instead.
    if not _BOLTS_AVAILABLE:
        raise ImportError("Bolts is required for this example, install it via pip install pytorch-lightning-bolts")
    if not _FAIRSCALE_PIPE_AVAILABLE:
        raise ImportError("FairScale and PyTorch 1.6 is required for this example.")

    # Command line: script-specific flags plus everything the Trainer and the datamodule expose.
    parser = ArgumentParser(description="Pipe Example")
    parser.add_argument("--use_rpc_sequential", action="store_true")
    parser.add_argument("--manual_optimization", action="store_true")
    parser = Trainer.add_argparse_args(parser)
    parser = pl_bolts.datamodules.CIFAR10DataModule.add_argparse_args(parser)
    args = parser.parse_args()

    cifar10_dm = instantiate_datamodule(args)

    # The RPC sequential plugin is only attached when requested on the command line.
    plugins = [RPCSequentialPlugin()] if args.use_rpc_sequential else None

    model = LitResnet(batch_size=args.batch_size, manual_optimization=args.manual_optimization)

    trainer = pl.Trainer.from_argparse_args(args, plugins=plugins)
    trainer.fit(model, cifar10_dm)
    trainer.test(model, datamodule=cifar10_dm)

    if trainer.accelerator.rpc_enabled:
        # Called at the end of trainer to ensure all processes are killed
        trainer.training_type_plugin.exit_rpc_process()
|