feat(semseg): allow model customization (#1371)
* feat(semantic_segmentation): allow customization of unet * feat(semseg): allow model customization * style(semseg): format to PEP8 * fix(semseg): rename logger * docs(changelog): updated semantic segmentation example * suggestions * suggestions * flake8 Co-authored-by: J. Borovec <jirka.borovec@seznam.cz>
This commit is contained in:
parent
e3001a0929
commit
06e6eadfaf
|
@ -17,14 +17,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Changed the default behaviour to no longer include a NaN check with each training iteration. ([#1475](https://github.com/PyTorchLightning/pytorch-lightning/pull/1475))
|
||||
|
||||
- Updated semantic segmentation example with custom u-net and logging ([#1371](https://github.com/PyTorchLightning/pytorch-lightning/pull/1371))
|
||||
|
||||
-
|
||||
|
||||
### Deprecated
|
||||
|
||||
-
|
||||
|
||||
-
|
||||
|
||||
|
||||
### Removed
|
||||
|
||||
|
|
|
@ -7,9 +7,14 @@ import torch.nn.functional as F
|
|||
import torchvision.transforms as transforms
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
import random
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pl_examples.models.unet import UNet
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1)
|
||||
DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33)
|
||||
|
||||
|
||||
class KITTI(Dataset):
|
||||
|
@ -34,14 +39,16 @@ class KITTI(Dataset):
|
|||
encoded using `encode_segmap`, and given `transform` (if any) are applied to the image only
|
||||
(mask does not usually require transforms, but they can be implemented in a similar way).
|
||||
"""
|
||||
IMAGE_PATH = os.path.join('training', 'image_2')
|
||||
MASK_PATH = os.path.join('training', 'semantic')
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root_path,
|
||||
split='test',
|
||||
img_size=(1242, 376),
|
||||
void_labels=[0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1],
|
||||
valid_labels=[7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33],
|
||||
data_path: str,
|
||||
split: str,
|
||||
img_size: tuple = (1242, 376),
|
||||
void_labels: list = DEFAULT_VOID_LABELS,
|
||||
valid_labels: list = DEFAULT_VALID_LABELS,
|
||||
transform=None
|
||||
):
|
||||
self.img_size = img_size
|
||||
|
@ -49,22 +56,23 @@ class KITTI(Dataset):
|
|||
self.valid_labels = valid_labels
|
||||
self.ignore_index = 250
|
||||
self.class_map = dict(zip(self.valid_labels, range(len(self.valid_labels))))
|
||||
self.split = split
|
||||
self.root = root_path
|
||||
if self.split == 'train':
|
||||
self.img_path = os.path.join(self.root, 'training/image_2')
|
||||
self.mask_path = os.path.join(self.root, 'training/semantic')
|
||||
else:
|
||||
self.img_path = os.path.join(self.root, 'testing/image_2')
|
||||
self.mask_path = None
|
||||
|
||||
self.transform = transform
|
||||
|
||||
self.split = split
|
||||
self.data_path = data_path
|
||||
self.img_path = os.path.join(self.data_path, self.IMAGE_PATH)
|
||||
self.mask_path = os.path.join(self.data_path, self.MASK_PATH)
|
||||
self.img_list = self.get_filenames(self.img_path)
|
||||
self.mask_list = self.get_filenames(self.mask_path)
|
||||
|
||||
# Split between train and valid set (80/20)
|
||||
random_inst = random.Random(12345) # for repeatability
|
||||
n_items = len(self.img_list)
|
||||
idxs = random_inst.sample(range(n_items), n_items // 5)
|
||||
if self.split == 'train':
|
||||
self.mask_list = self.get_filenames(self.mask_path)
|
||||
else:
|
||||
self.mask_list = None
|
||||
idxs = [idx for idx in range(n_items) if idx not in idxs]
|
||||
self.img_list = [self.img_list[i] for i in idxs]
|
||||
self.mask_list = [self.mask_list[i] for i in idxs]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.img_list)
|
||||
|
@ -74,19 +82,15 @@ class KITTI(Dataset):
|
|||
img = img.resize(self.img_size)
|
||||
img = np.array(img)
|
||||
|
||||
if self.split == 'train':
|
||||
mask = Image.open(self.mask_list[idx]).convert('L')
|
||||
mask = mask.resize(self.img_size)
|
||||
mask = np.array(mask)
|
||||
mask = self.encode_segmap(mask)
|
||||
mask = Image.open(self.mask_list[idx]).convert('L')
|
||||
mask = mask.resize(self.img_size)
|
||||
mask = np.array(mask)
|
||||
mask = self.encode_segmap(mask)
|
||||
|
||||
if self.transform:
|
||||
img = self.transform(img)
|
||||
|
||||
if self.split == 'train':
|
||||
return img, mask
|
||||
else:
|
||||
return img
|
||||
return img, mask
|
||||
|
||||
def encode_segmap(self, mask):
|
||||
"""
|
||||
|
@ -96,6 +100,8 @@ class KITTI(Dataset):
|
|||
mask[mask == voidc] = self.ignore_index
|
||||
for validc in self.valid_labels:
|
||||
mask[mask == validc] = self.class_map[validc]
|
||||
# remove extra idxs from updated dataset
|
||||
mask[mask > 18] = self.ignore_index
|
||||
return mask
|
||||
|
||||
def get_filenames(self, path):
|
||||
|
@ -124,17 +130,19 @@ class SegModel(pl.LightningModule):
|
|||
|
||||
def __init__(self, hparams):
|
||||
super().__init__()
|
||||
self.root_path = hparams.root
|
||||
self.hparams = hparams
|
||||
self.data_path = hparams.data_path
|
||||
self.batch_size = hparams.batch_size
|
||||
self.learning_rate = hparams.lr
|
||||
self.net = UNet(num_classes=19)
|
||||
self.net = UNet(num_classes=19, num_layers=hparams.num_layers,
|
||||
features_start=hparams.features_start, bilinear=hparams.bilinear)
|
||||
self.transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
|
||||
std=[0.32064945, 0.32098866, 0.32325324])
|
||||
])
|
||||
self.trainset = KITTI(self.root_path, split='train', transform=self.transform)
|
||||
self.testset = KITTI(self.root_path, split='test', transform=self.transform)
|
||||
self.trainset = KITTI(self.data_path, split='train', transform=self.transform)
|
||||
self.validset = KITTI(self.data_path, split='valid', transform=self.transform)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
@ -145,7 +153,21 @@ class SegModel(pl.LightningModule):
|
|||
mask = mask.long()
|
||||
out = self(img)
|
||||
loss_val = F.cross_entropy(out, mask, ignore_index=250)
|
||||
return {'loss': loss_val}
|
||||
log_dict = {'train_loss': loss_val}
|
||||
return {'loss': loss_val, 'log': log_dict, 'progress_bar': log_dict}
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
img, mask = batch
|
||||
img = img.float()
|
||||
mask = mask.long()
|
||||
out = self(img)
|
||||
loss_val = F.cross_entropy(out, mask, ignore_index=250)
|
||||
return {'val_loss': loss_val}
|
||||
|
||||
def validation_epoch_end(self, outputs):
|
||||
loss_val = sum(output['val_loss'] for output in outputs) / len(outputs)
|
||||
log_dict = {'val_loss': loss_val}
|
||||
return {'log': log_dict, 'val_loss': log_dict['val_loss'], 'progress_bar': log_dict}
|
||||
|
||||
def configure_optimizers(self):
|
||||
opt = torch.optim.Adam(self.net.parameters(), lr=self.learning_rate)
|
||||
|
@ -155,8 +177,8 @@ class SegModel(pl.LightningModule):
|
|||
def train_dataloader(self):
|
||||
return DataLoader(self.trainset, batch_size=self.batch_size, shuffle=True)
|
||||
|
||||
def test_dataloader(self):
|
||||
return DataLoader(self.testset, batch_size=self.batch_size, shuffle=False)
|
||||
def val_dataloader(self):
|
||||
return DataLoader(self.validset, batch_size=self.batch_size, shuffle=False)
|
||||
|
||||
|
||||
def main(hparams):
|
||||
|
@ -166,24 +188,49 @@ def main(hparams):
|
|||
model = SegModel(hparams)
|
||||
|
||||
# ------------------------
|
||||
# 2 INIT TRAINER
|
||||
# 2 SET LOGGER
|
||||
# ------------------------
|
||||
logger = False
|
||||
if hparams.log_wandb:
|
||||
logger = WandbLogger()
|
||||
|
||||
# optional: log model topology
|
||||
logger.watch(model.net)
|
||||
|
||||
# ------------------------
|
||||
# 3 INIT TRAINER
|
||||
# ------------------------
|
||||
trainer = pl.Trainer(
|
||||
gpus=hparams.gpus
|
||||
gpus=hparams.gpus,
|
||||
logger=logger,
|
||||
max_epochs=hparams.epochs,
|
||||
accumulate_grad_batches=hparams.grad_batches,
|
||||
distributed_backend=hparams.distributed_backend,
|
||||
precision=16 if hparams.use_amp else 32,
|
||||
)
|
||||
|
||||
# ------------------------
|
||||
# 3 START TRAINING
|
||||
# 5 START TRAINING
|
||||
# ------------------------
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--root", type=str, help="path where dataset is stored")
|
||||
parser.add_argument("--gpus", type=int, help="number of available GPUs")
|
||||
parser.add_argument("--data_path", type=str, help="path where dataset is stored")
|
||||
parser.add_argument("--gpus", type=int, default=-1, help="number of available GPUs")
|
||||
parser.add_argument('--distributed-backend', type=str, default='dp', choices=('dp', 'ddp', 'ddp2'),
|
||||
help='supports three options dp, ddp, ddp2')
|
||||
parser.add_argument('--use_amp', action='store_true', help='if true uses 16 bit precision')
|
||||
parser.add_argument("--batch_size", type=int, default=4, help="size of the batches")
|
||||
parser.add_argument("--lr", type=float, default=0.001, help="adam: learning rate")
|
||||
parser.add_argument("--num_layers", type=int, default=5, help="number of layers on u-net")
|
||||
parser.add_argument("--features_start", type=float, default=64, help="number of features in first layer")
|
||||
parser.add_argument("--bilinear", action='store_true', default=False,
|
||||
help="whether to use bilinear interpolation or transposed")
|
||||
parser.add_argument("--grad_batches", type=int, default=1, help="number of batches to accumulate")
|
||||
parser.add_argument("--epochs", type=int, default=20, help="number of epochs to train")
|
||||
parser.add_argument("--log_wandb", action='store_true', help="log training on Weights & Biases")
|
||||
|
||||
hparams = parser.parse_args()
|
||||
|
||||
|
|
|
@ -9,39 +9,46 @@ class UNet(nn.Module):
|
|||
Link - https://arxiv.org/abs/1505.04597
|
||||
|
||||
Parameters:
|
||||
num_classes (int): Number of output classes required (default 19 for KITTI dataset)
|
||||
bilinear (bool): Whether to use bilinear interpolation or transposed
|
||||
num_classes: Number of output classes required (default 19 for KITTI dataset)
|
||||
num_layers: Number of layers in each side of U-net
|
||||
features_start: Number of features in first layer
|
||||
bilinear: Whether to use bilinear interpolation or transposed
|
||||
convolutions for upsampling.
|
||||
"""
|
||||
|
||||
def __init__(self, num_classes=19, bilinear=False):
|
||||
def __init__(
|
||||
self, num_classes: int = 19,
|
||||
num_layers: int = 5,
|
||||
features_start: int = 64,
|
||||
bilinear: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
self.layer1 = DoubleConv(3, 64)
|
||||
self.layer2 = Down(64, 128)
|
||||
self.layer3 = Down(128, 256)
|
||||
self.layer4 = Down(256, 512)
|
||||
self.layer5 = Down(512, 1024)
|
||||
self.num_layers = num_layers
|
||||
|
||||
self.layer6 = Up(1024, 512, bilinear=bilinear)
|
||||
self.layer7 = Up(512, 256, bilinear=bilinear)
|
||||
self.layer8 = Up(256, 128, bilinear=bilinear)
|
||||
self.layer9 = Up(128, 64, bilinear=bilinear)
|
||||
layers = [DoubleConv(3, features_start)]
|
||||
|
||||
self.layer10 = nn.Conv2d(64, num_classes, kernel_size=1)
|
||||
feats = features_start
|
||||
for _ in range(num_layers - 1):
|
||||
layers.append(Down(feats, feats * 2))
|
||||
feats *= 2
|
||||
|
||||
for _ in range(num_layers - 1):
|
||||
layers.append(Up(feats, feats // 2), bilinear)
|
||||
feats //= 2
|
||||
|
||||
layers.append(nn.Conv2d(feats, num_classes, kernel_size=1))
|
||||
|
||||
self.layers = nn.ModuleList(layers)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.layer1(x)
|
||||
x2 = self.layer2(x1)
|
||||
x3 = self.layer3(x2)
|
||||
x4 = self.layer4(x3)
|
||||
x5 = self.layer5(x4)
|
||||
|
||||
x6 = self.layer6(x5, x4)
|
||||
x6 = self.layer7(x6, x3)
|
||||
x6 = self.layer8(x6, x2)
|
||||
x6 = self.layer9(x6, x1)
|
||||
|
||||
return self.layer10(x6)
|
||||
xi = [self.layers[0](x)]
|
||||
# Down path
|
||||
for layer in self.layers[1:self.num_layers]:
|
||||
xi.append(layer(xi[-1]))
|
||||
# Up path
|
||||
for i, layer in enumerate(self.layers[self.num_layers:-1]):
|
||||
xi[-1] = layer(xi[-1], xi[-2 - i])
|
||||
return self.layers[-1](xi[-1])
|
||||
|
||||
|
||||
class DoubleConv(nn.Module):
|
||||
|
@ -50,7 +57,7 @@ class DoubleConv(nn.Module):
|
|||
(3x3 conv -> BN -> ReLU) ** 2
|
||||
"""
|
||||
|
||||
def __init__(self, in_ch, out_ch):
|
||||
def __init__(self, in_ch: int, out_ch: int):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
|
||||
|
@ -70,7 +77,7 @@ class Down(nn.Module):
|
|||
Combination of MaxPool2d and DoubleConv in series
|
||||
"""
|
||||
|
||||
def __init__(self, in_ch, out_ch):
|
||||
def __init__(self, in_ch: int, out_ch: int):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
|
@ -88,7 +95,7 @@ class Up(nn.Module):
|
|||
followed by double 3x3 convolution.
|
||||
"""
|
||||
|
||||
def __init__(self, in_ch, out_ch, bilinear=False):
|
||||
def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False):
|
||||
super().__init__()
|
||||
self.upsample = None
|
||||
if bilinear:
|
||||
|
|
Loading…
Reference in New Issue